// binance_sdk/common/websocket.rs

1use async_trait::async_trait;
2use flate2::read::ZlibDecoder;
3use futures::{SinkExt, StreamExt, future::try_join_all, stream::FuturesUnordered};
4use http::header::USER_AGENT;
5use serde::de::DeserializeOwned;
6use serde_json::{Value, json};
7use std::{
8    collections::{BTreeMap, HashMap, VecDeque},
9    io::Read,
10    marker::PhantomData,
11    mem::take,
12    sync::{
13        Arc,
14        atomic::{AtomicBool, AtomicUsize, Ordering},
15    },
16    time::Duration,
17};
18use tokio::{
19    net::TcpStream,
20    select, spawn,
21    sync::{
22        Mutex, Notify,
23        mpsc::{Receiver, Sender, UnboundedSender, channel, unbounded_channel},
24        oneshot,
25    },
26    task::JoinHandle,
27    time::{sleep, timeout},
28};
29use tokio_tungstenite::{
30    Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config,
31    tungstenite::{
32        Message,
33        client::IntoClientRequest,
34        protocol::{CloseFrame, WebSocketConfig, frame::coding::CloseCode},
35    },
36};
37use tokio_util::time::DelayQueue;
38use tracing::{debug, error, info, warn};
39
40use super::{
41    config::{AgentConnector, ConfigurationWebsocketApi, ConfigurationWebsocketStreams},
42    errors::{WebsocketConnectionFailureReason, WebsocketError},
43    models::{StreamId, WebsocketApiResponse, WebsocketEvent, WebsocketMode},
44    utils::{build_websocket_api_message, normalize_stream_id, random_string, validate_time_unit},
45};
46
47pub type WebSocketClient = WebSocketStream<MaybeTlsStream<TcpStream>>;
48
/// Maximum lifetime of a single connection (23 hours). The renewal loop
/// schedules a renewal this long after it receives a renewal request.
/// Presumably chosen to stay below a server-side 24-hour connection limit —
/// confirm against the exchange documentation.
const MAX_CONN_DURATION: Duration = Duration::from_secs(23 * 3600);
50
51pub struct Subscription {
52    handle: JoinHandle<()>,
53}
54
55impl Subscription {
56    /// Cancels the ongoing WebSocket event subscription and stops the event processing task.
57    ///
58    /// This method aborts the background task responsible for receiving and processing
59    /// WebSocket events, effectively unsubscribing from further event notifications.
60    ///
61    /// # Examples
62    ///
63    ///
64    /// let emitter = `WebsocketEventEmitter::new()`;
65    /// let subscription = emitter.subscribe(|event| {
66    ///     // Handle WebSocket event
67    /// });
68    /// `subscription.unsubscribe()`; // Stop receiving events
69    ///
70    pub fn unsubscribe(self) {
71        self.handle.abort();
72    }
73}
74
75#[derive(Clone)]
76pub enum WebsocketBase {
77    WebsocketApi(Arc<WebsocketApi>),
78    WebsocketStreams(Arc<WebsocketStreams>),
79}
80
81pub struct WebsocketEventEmitter {
82    subscribers: Arc<std::sync::Mutex<Vec<UnboundedSender<WebsocketEvent>>>>,
83}
84
85impl Default for WebsocketEventEmitter {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl WebsocketEventEmitter {
92    #[must_use]
93    pub fn new() -> Self {
94        Self {
95            subscribers: Arc::new(std::sync::Mutex::new(Vec::new())),
96        }
97    }
98
99    /// Subscribes to WebSocket events and returns a `Subscription` that allows event processing.
100    ///
101    /// This method creates an unbounded channel for receiving WebSocket events and
102    /// spawns an asynchronous task to process these events using the provided callback function.
103    ///
104    /// # Arguments
105    ///
106    /// * `callback` - A mutable function that will be called for each received WebSocket event.
107    ///   The callback must be thread-safe and have a static lifetime.
108    ///
109    /// # Returns
110    ///
111    /// A `Subscription` that can be used to unsubscribe and stop event processing.
112    ///
113    /// # Examples
114    ///
115    ///
116    /// let emitter = `WebsocketEventEmitter::new()`;
117    /// let subscription = emitter.subscribe(|event| {
118    ///     // Handle WebSocket event
119    ///     println!("Received event: {:?}", event);
120    /// });
121    ///
122    /// // Later, when no longer needed
123    /// `subscription.unsubscribe()`;
124    ///
125    pub fn subscribe<F>(&self, mut callback: F) -> Subscription
126    where
127        F: FnMut(WebsocketEvent) + Send + 'static,
128    {
129        let (tx, mut rx) = unbounded_channel();
130        let mut guard = match self.subscribers.lock() {
131            Ok(guard) => guard,
132            Err(poisoned) => poisoned.into_inner(),
133        };
134        guard.push(tx);
135        drop(guard);
136
137        let handle = spawn(async move {
138            while let Some(event) = rx.recv().await {
139                callback(event);
140            }
141        });
142        Subscription { handle }
143    }
144
145    /// Emits a WebSocket event to all registered subscribers.
146    ///
147    /// This method sends the given event to all active subscribers. If a subscriber
148    /// has been dropped without unsubscribing, a warning is logged and the subscriber
149    /// is removed from the list.
150    ///
151    /// # Arguments
152    ///
153    /// * `event` - The WebSocket event to be emitted to all subscribers.
154    pub fn emit(&self, event: &WebsocketEvent) {
155        let mut guard = match self.subscribers.lock() {
156            Ok(guard) => guard,
157            Err(poisoned) => poisoned.into_inner(),
158        };
159
160        guard.retain(|tx| {
161            if tx.send(event.clone()).is_ok() {
162                true
163            } else {
164                warn!("subscriber dropped without unsubscribing");
165                false
166            }
167        });
168    }
169}
170
171/// A trait defining the lifecycle and behavior of a WebSocket connection.
172///
173/// This trait provides methods for handling WebSocket connection events,
174/// including connection opening, message handling, and reconnection URL retrieval.
175///
176/// # Methods
177///
178/// * `on_open`: Called when a WebSocket connection is established
179/// * `on_message`: Called when a message is received over the WebSocket
180/// * `get_reconnect_url`: Determines the URL to use for reconnecting
181///
182/// # Thread Safety
183///
184/// Implementors must be safely shareable across threads, as indicated by the `Send + Sync + 'static` bounds.
185#[async_trait]
186pub trait WebsocketHandler: Send + Sync + 'static {
187    async fn on_open(&self, url: String, connection: Arc<WebsocketConnection>);
188    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>);
189    async fn get_reconnect_url(
190        &self,
191        default_url: String,
192        connection: Arc<WebsocketConnection>,
193    ) -> String;
194}
195
196pub struct PendingRequest {
197    pub completion: oneshot::Sender<Result<Value, WebsocketError>>,
198}
199
200#[derive(Clone)]
201pub struct WebsocketSessionLogonReq {
202    pub method: String,
203    pub payload: BTreeMap<String, Value>,
204    pub options: WebsocketMessageSendOptions,
205}
206
207pub struct WebsocketConnectionState {
208    pub reconnection_pending: bool,
209    pub renewal_pending: bool,
210    pub close_initiated: bool,
211    pub pending_requests: HashMap<String, PendingRequest>,
212    pub pending_subscriptions: VecDeque<String>,
213    pub stream_callbacks: HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>,
214    pub is_session_logged_on: bool,
215    pub session_logon_req: Option<WebsocketSessionLogonReq>,
216    pub url_path: Option<String>,
217    pub handler: Option<Arc<dyn WebsocketHandler>>,
218    pub ws_write_tx: Option<UnboundedSender<Message>>,
219}
220
221impl Default for WebsocketConnectionState {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
227impl WebsocketConnectionState {
228    #[must_use]
229    pub fn new() -> Self {
230        Self {
231            reconnection_pending: false,
232            renewal_pending: false,
233            close_initiated: false,
234            pending_requests: HashMap::new(),
235            pending_subscriptions: VecDeque::new(),
236            stream_callbacks: HashMap::new(),
237            is_session_logged_on: false,
238            session_logon_req: None,
239            url_path: None,
240            handler: None,
241            ws_write_tx: None,
242        }
243    }
244}
245
246pub struct WebsocketConnection {
247    pub id: String,
248    pub drain_notify: Notify,
249    pub state: Mutex<WebsocketConnectionState>,
250}
251
252impl WebsocketConnection {
253    pub fn new(id: impl Into<String>) -> Arc<Self> {
254        Arc::new(Self {
255            id: id.into(),
256            drain_notify: Notify::new(),
257            state: Mutex::new(WebsocketConnectionState::new()),
258        })
259    }
260
261    pub async fn set_handler(&self, handler: Arc<dyn WebsocketHandler>) {
262        let mut conn_state = self.state.lock().await;
263        conn_state.handler = Some(handler);
264    }
265}
266
/// A request for the reconnect loop to (re)establish one pooled connection.
struct ReconnectEntry {
    /// `id` of the target connection in the pool.
    connection_id: String,
    /// URL to connect to.
    url: String,
    /// `true` for a scheduled renewal. The reconnect delay is skipped and
    /// `init_connect` runs in renewal mode.
    is_renewal: bool,
}
272
273pub struct WebsocketCommon {
274    pub events: WebsocketEventEmitter,
275    mode: WebsocketMode,
276    round_robin_index: AtomicUsize,
277    connection_pool: Vec<Arc<WebsocketConnection>>,
278    reconnect_tx: Sender<ReconnectEntry>,
279    renewal_tx: Sender<(String, String)>,
280    reconnect_delay: usize,
281    agent: Option<AgentConnector>,
282    user_agent: Option<String>,
283}
284
285impl WebsocketCommon {
286    #[must_use]
287    pub fn new(
288        mut initial_pool: Vec<Arc<WebsocketConnection>>,
289        mode: WebsocketMode,
290        reconnect_delay: usize,
291        agent: Option<AgentConnector>,
292        user_agent: Option<String>,
293    ) -> Arc<Self> {
294        if initial_pool.is_empty() {
295            for _ in 0..mode.pool_size() {
296                let id = random_string();
297                initial_pool.push(WebsocketConnection::new(id));
298            }
299        }
300
301        let (reconnect_tx, reconnect_rx) = channel::<ReconnectEntry>(mode.pool_size());
302        let (renewal_tx, renewal_rx) = channel::<(String, String)>(mode.pool_size());
303
304        let common = Arc::new(Self {
305            events: WebsocketEventEmitter::new(),
306            mode,
307            round_robin_index: AtomicUsize::new(0),
308            connection_pool: initial_pool,
309            reconnect_tx,
310            renewal_tx,
311            reconnect_delay,
312            agent,
313            user_agent,
314        });
315
316        Self::spawn_reconnect_loop(Arc::clone(&common), reconnect_rx);
317        Self::spawn_renewal_loop(&Arc::clone(&common), renewal_rx);
318
319        common
320    }
321
322    /// Spawns an asynchronous loop to handle websocket reconnection attempts
323    ///
324    /// This method manages reconnection logic for websocket connections, including:
325    /// - Scheduling reconnects with a configurable delay
326    /// - Finding the appropriate connection in the connection pool
327    /// - Attempting to reinitialize the connection
328    /// - Logging reconnection failures or warnings
329    ///
330    /// # Arguments
331    /// * `common` - A shared reference to the `WebsocketCommon` instance
332    /// * `reconnect_rx` - A receiver channel for reconnection entries
333    ///
334    /// # Behavior
335    /// - Waits for reconnection entries from the channel
336    /// - Applies a configurable delay before attempting reconnection
337    /// - Attempts to reinitialize the connection with the provided URL
338    /// - Handles and logs any reconnection errors
339    fn spawn_reconnect_loop(common: Arc<Self>, mut reconnect_rx: Receiver<ReconnectEntry>) {
340        spawn(async move {
341            while let Some(entry) = reconnect_rx.recv().await {
342                info!("Scheduling reconnect for id {}", entry.connection_id);
343
344                if !entry.is_renewal {
345                    sleep(Duration::from_millis(common.reconnect_delay as u64)).await;
346                }
347
348                if let Some(conn_arc) = common
349                    .connection_pool
350                    .iter()
351                    .find(|c| c.id == entry.connection_id)
352                    .cloned()
353                {
354                    let common_clone = Arc::clone(&common);
355                    if let Err(err) = common_clone
356                        .init_connect(&entry.url, entry.is_renewal, Some(conn_arc.clone()))
357                        .await
358                    {
359                        error!(
360                            "Reconnect failed for {} → {}: {:?}",
361                            entry.connection_id, entry.url, err
362                        );
363                    }
364
365                    sleep(Duration::from_secs(1)).await;
366                } else {
367                    warn!("No connection {} found for reconnect", entry.connection_id);
368                }
369            }
370        });
371    }
372
373    /// Spawns an asynchronous loop to manage connection renewals
374    ///
375    /// This method handles the periodic renewal of websocket connections by:
376    /// - Maintaining a delay queue for connection expiration
377    /// - Receiving renewal requests for specific connections
378    /// - Triggering reconnection when a connection reaches its maximum duration
379    /// - Attempting to find and renew connections in the connection pool
380    ///
381    /// # Behavior
382    /// - Listens for renewal requests on a channel
383    /// - Tracks connection expiration using a delay queue
384    /// - Initiates reconnection process when a connection expires
385    /// - Handles and logs any renewal failures
386    fn spawn_renewal_loop(common: &Arc<Self>, renewal_rx: Receiver<(String, String)>) {
387        let common = Arc::clone(common);
388        spawn(async move {
389            let mut dq = DelayQueue::new();
390            let mut renewal_rx = renewal_rx;
391
392            loop {
393                select! {
394                    Some((conn_id, url)) = renewal_rx.recv() => {
395                        debug!("Scheduling renewal for {}", conn_id);
396                        dq.insert((conn_id, url), MAX_CONN_DURATION);
397                    }
398
399                    Some(expired) = dq.next() => {
400                        let (conn_id, default_url) = expired.into_inner();
401
402                        if let Some(conn_arc) = common
403                            .connection_pool
404                            .iter()
405                            .find(|c| c.id == conn_id)
406                            .cloned()
407                        {
408                            debug!("Renewing connection {}", conn_id);
409                            let url = common
410                                .get_reconnect_url(&default_url, Arc::clone(&conn_arc))
411                                .await;
412                            if let Err(e) = common.reconnect_tx.send(ReconnectEntry {
413                                connection_id: conn_id.clone(),
414                                url,
415                                is_renewal: true,
416                            }).await {
417                                error!(
418                                    "Failed to enqueue renewal for {}: {:?}",
419                                    conn_id, e
420                                );
421                            }
422                        } else {
423                            warn!("No connection {} found for renewal", conn_id);
424                        }
425                    }
426                }
427            }
428        });
429    }
430
431    /// Checks if a WebSocket connection is ready for use.
432    ///
433    /// # Arguments
434    ///
435    /// * `connection` - The WebSocket connection to check
436    /// * `allow_non_established` - If true, allows connections that are not fully established
437    ///
438    /// # Returns
439    ///
440    /// `true` if the connection is ready, `false` otherwise
441    ///
442    /// # Behavior
443    ///
444    /// A connection is considered ready if:
445    /// - It has a write channel (unless `allow_non_established` is true)
446    /// - No reconnection is pending
447    /// - No close has been initiated
448    pub async fn is_connection_ready(
449        &self,
450        connection: &WebsocketConnection,
451        allow_non_established: bool,
452    ) -> bool {
453        let conn_state = connection.state.lock().await;
454        (allow_non_established || conn_state.ws_write_tx.is_some())
455            && !conn_state.reconnection_pending
456            && !conn_state.close_initiated
457    }
458
459    /// Checks if a WebSocket connection is established.
460    ///
461    /// # Arguments
462    ///
463    /// * `connection` - Optional specific WebSocket connection to check
464    ///
465    /// # Returns
466    ///
467    /// `true` if a connection is ready and established, `false` otherwise
468    ///
469    /// # Behavior
470    ///
471    /// - If a specific connection is provided, checks only that connection
472    /// - If no connection is provided, checks all connections in the pool
473    /// - A connection is considered established if it is ready and not in a non-established state
474    async fn is_connected(&self, connection: Option<&Arc<WebsocketConnection>>) -> bool {
475        if let Some(conn_arc) = connection {
476            return self.is_connection_ready(conn_arc, false).await;
477        }
478
479        for conn_arc in &self.connection_pool {
480            if self.is_connection_ready(conn_arc, false).await {
481                return true;
482            }
483        }
484
485        false
486    }
487
488    /// Retrieves available WebSocket connections from the connection pool.
489    ///
490    /// # Arguments
491    ///
492    /// * `allow_non_established` - If `true`, includes connections that are not fully established
493    /// * `url_path` - Optional URL path to filter connections
494    ///
495    /// # Returns
496    ///
497    /// A vector of `Arc<WebsocketConnection>` that are ready based on the `allow_non_established` flag
498    ///
499    /// # Behavior
500    ///
501    /// - For single connection mode, returns the first connection
502    /// - For multi-connection mode, filters connections based on readiness
503    /// - Uses `is_connection_ready` to determine connection availability
504    async fn get_available_connections(
505        &self,
506        allow_non_established: bool,
507        url_path: Option<&str>,
508    ) -> Vec<Arc<WebsocketConnection>> {
509        if matches!(self.mode, WebsocketMode::Single) && url_path.is_none() {
510            return vec![Arc::clone(&self.connection_pool[0])];
511        }
512
513        let mut ready = Vec::new();
514        for conn in &self.connection_pool {
515            if self.is_connection_ready(conn, allow_non_established).await {
516                ready.push(Arc::clone(conn));
517            }
518        }
519
520        ready
521    }
522
523    /// Retrieves a WebSocket connection from the connection pool.
524    ///
525    /// # Arguments
526    ///
527    /// * `allow_non_established` - If `true`, allows selecting a connection that is not fully established
528    /// * `url_path` - Optional URL path to filter connections
529    ///
530    /// # Returns
531    ///
532    /// An `Arc` to a `WebsocketConnection` from the pool, selected using round-robin strategy
533    ///
534    /// # Errors
535    ///
536    /// Returns `WebsocketError::NotConnected` if no suitable connection is available
537    ///
538    /// # Behavior
539    ///
540    /// - For single connection mode, returns the first connection
541    /// - For multi-connection mode, selects a ready connection using round-robin
542    /// - Filters connections based on `allow_non_established` parameter
543    async fn get_connection(
544        &self,
545        allow_non_established: bool,
546        url_path: Option<&str>,
547    ) -> Result<Arc<WebsocketConnection>, WebsocketError> {
548        let candidates = self
549            .get_available_connections(allow_non_established, url_path)
550            .await;
551
552        let mut ready = Vec::new();
553        for conn in candidates {
554            if let Some(path) = url_path {
555                let st = conn.state.lock().await;
556                if st.url_path.as_deref() != Some(path) {
557                    continue;
558                }
559            }
560            ready.push(conn);
561        }
562
563        if ready.is_empty() {
564            return Err(WebsocketError::NotConnected);
565        }
566
567        let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % ready.len();
568
569        Ok(Arc::clone(&ready[idx]))
570    }
571
572    /// Gracefully closes a WebSocket connection by waiting for pending requests to complete.
573    ///
574    /// # Arguments
575    ///
576    /// * `ws_write_tx_to_close` - Sender channel for sending close message
577    /// * `connection` - Shared reference to the WebSocket connection
578    ///
579    /// # Behavior
580    ///
581    /// - Waits up to 30 seconds for all pending requests to complete
582    /// - Logs debug and warning messages during the closing process
583    /// - Sends a normal close frame to the WebSocket
584    ///
585    /// # Returns
586    ///
587    /// `Ok(())` if connection closes successfully, otherwise a `WebsocketError`
588    async fn close_connection_gracefully(
589        &self,
590        ws_write_tx_to_close: UnboundedSender<Message>,
591        connection: Arc<WebsocketConnection>,
592    ) -> Result<(), WebsocketError> {
593        debug!("Waiting for pending requests to complete before disconnecting.");
594
595        let drain = async {
596            loop {
597                {
598                    let conn_state = connection.state.lock().await;
599                    if conn_state.pending_requests.is_empty() {
600                        debug!("All pending requests completed, proceeding to close.");
601                        break;
602                    }
603                }
604                connection.drain_notify.notified().await;
605            }
606        };
607
608        if timeout(Duration::from_secs(30), drain).await.is_err() {
609            warn!("Timeout waiting for pending requests; forcing close.");
610        }
611
612        info!("Closing WebSocket connection for {}", connection.id);
613        let _ = ws_write_tx_to_close.send(Message::Close(Some(CloseFrame {
614            code: CloseCode::Normal,
615            reason: "".into(),
616        })));
617
618        Ok(())
619    }
620
621    /// Retrieves the URL to use for reconnecting to the WebSocket.
622    ///
623    /// # Arguments
624    ///
625    /// * `default_url` - The default URL to use if no custom reconnect URL is provided
626    /// * `connection` - A shared reference to the WebSocket connection
627    ///
628    /// # Returns
629    ///
630    /// The URL to use for reconnecting, either from a custom handler or the default URL
631    ///
632    /// # Behavior
633    ///
634    /// - Checks if a connection handler is available
635    /// - If a handler exists, calls its `get_reconnect_url` method
636    /// - Otherwise, returns the default URL
637    async fn get_reconnect_url(
638        &self,
639        default_url: &str,
640        connection: Arc<WebsocketConnection>,
641    ) -> String {
642        if let Some(handler) = {
643            let conn_state = connection.state.lock().await;
644            conn_state.handler.clone()
645        } {
646            return handler
647                .get_reconnect_url(default_url.to_string(), Arc::clone(&connection))
648                .await;
649        }
650
651        default_url.to_string()
652    }
653
654    /// Handles the WebSocket connection opening event.
655    ///
656    /// This method is called when a WebSocket connection is successfully established. It performs
657    /// the following key actions:
658    /// - Invokes the connection handler's `on_open` method if a handler is present
659    /// - Logs connection information
660    /// - Handles connection renewal and close scenarios
661    /// - Emits a WebSocket open event
662    ///
663    /// # Arguments
664    ///
665    /// * `url` - The URL of the WebSocket server
666    /// * `connection` - A shared reference to the WebSocket connection
667    /// * `old_ws_writer` - Optional previous WebSocket writer for graceful connection handling
668    ///
669    /// # Behavior
670    ///
671    /// - If a connection handler exists, calls its `on_open` method
672    /// - Checks for pending renewal or close states
673    /// - Closes the previous connection if renewal is in progress
674    /// - Emits an open event if the connection is successfully established
675    async fn on_open(
676        &self,
677        url: String,
678        connection: Arc<WebsocketConnection>,
679        old_ws_writer: Option<UnboundedSender<Message>>,
680    ) {
681        if let Some(handler) = {
682            let conn_state = connection.state.lock().await;
683            conn_state.handler.clone()
684        } {
685            handler.on_open(url.clone(), Arc::clone(&connection)).await;
686        }
687
688        let conn_id = &connection.id;
689        info!("Connected to WebSocket Server with id {}: {}", conn_id, url);
690
691        {
692            let mut conn_state = connection.state.lock().await;
693
694            if conn_state.renewal_pending {
695                conn_state.renewal_pending = false;
696                drop(conn_state);
697                if let Some(tx) = old_ws_writer {
698                    info!("Connection renewal in progress; closing previous connection.");
699                    let _ = self
700                        .close_connection_gracefully(tx, Arc::clone(&connection))
701                        .await;
702                }
703                return;
704            }
705
706            if conn_state.close_initiated {
707                drop(conn_state);
708                if let Some(tx) = connection.state.lock().await.ws_write_tx.clone() {
709                    info!("Close initiated; closing connection.");
710                    let _ = self
711                        .close_connection_gracefully(tx, Arc::clone(&connection))
712                        .await;
713                }
714                return;
715            }
716
717            self.events.emit(&WebsocketEvent::Open);
718        }
719    }
720
721    /// Handles an incoming WebSocket message
722    ///
723    /// # Arguments
724    ///
725    /// * `msg` - The received message as a string
726    /// * `connection` - A shared reference to the WebSocket connection
727    ///
728    /// # Behavior
729    ///
730    /// - If a connection handler exists, spawns an async task to call its `on_message` method
731    /// - Emits a `WebsocketEvent::Message` event with the received message
732    async fn on_message(&self, msg: String, connection: Arc<WebsocketConnection>) {
733        if let Some(handler) = connection.state.lock().await.handler.clone() {
734            let handler_clone = handler.clone();
735            let data = msg.clone();
736            let conn_clone = connection.clone();
737            spawn(async move {
738                handler_clone.on_message(data, conn_clone).await;
739            });
740        }
741        self.events.emit(&WebsocketEvent::Message(msg));
742    }
743
744    /// Creates a WebSocket connection with optional configuration and agent
745    ///
746    /// # Arguments
747    ///
748    /// * `url` - The WebSocket server URL to connect to
749    /// * `agent` - Optional agent connector for configuring the connection
750    /// * `user_agent` - Optional custom user agent string
751    ///
752    /// # Returns
753    ///
754    /// A `Result` containing the established WebSocket stream or a `WebsocketError`
755    ///
756    /// # Errors
757    ///
758    /// Returns a `WebsocketError` if:
759    /// - The WebSocket handshake fails
760    /// - The connection times out after 10 seconds
761    ///
762    /// # Behavior
763    ///
764    /// Attempts to establish a WebSocket connection with a configurable timeout,
765    /// supporting optional TLS, custom user agent, and connection connectors
766    async fn create_websocket(
767        url: &str,
768        agent: Option<AgentConnector>,
769        user_agent: Option<String>,
770    ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, WebsocketError> {
771        let mut req = url
772            .into_client_request()
773            .map_err(|e| WebsocketError::Handshake(e.to_string()))?;
774
775        if let Some(ua) = user_agent {
776            req.headers_mut().insert(USER_AGENT, ua.parse().unwrap());
777        }
778
779        let ws_config: Option<WebSocketConfig> = None;
780        let disable_nagle = false;
781        let connector: Option<Connector> = agent.map(|dbg| dbg.0);
782
783        let timeout_duration = Duration::from_secs(10);
784        let handshake = connect_async_tls_with_config(req, ws_config, disable_nagle, connector);
785        match timeout(timeout_duration, handshake).await {
786            Ok(Ok((ws_stream, response))) => {
787                debug!("WebSocket connected: {:?}", response);
788                Ok(ws_stream)
789            }
790            Ok(Err(e)) => {
791                let msg = e.to_string();
792                error!("WebSocket handshake failed: {}", msg);
793                Err(WebsocketError::Handshake(msg))
794            }
795            Err(_) => {
796                error!(
797                    "WebSocket connection timed out after {}s",
798                    timeout_duration.as_secs()
799                );
800                Err(WebsocketError::Timeout)
801            }
802        }
803    }
804
805    /// Connects to a WebSocket URL for all connections in the connection pool concurrently
806    ///
807    /// # Arguments
808    ///
809    /// * `url` - The WebSocket server URL to connect to
810    /// * `connections` - Optional specific connections to use, otherwise uses the entire pool
811    ///
812    /// # Returns
813    ///
814    /// A `Result` indicating whether all connections were successfully established
815    ///
816    /// # Errors
817    ///
818    /// Returns a `WebsocketError` if any connection in the pool fails to establish
819    ///
820    /// # Behavior
821    ///
822    /// Attempts to initialize a WebSocket connection for each connection in the pool
823    /// concurrently, logging successes and failures for each connection attempt
824    async fn connect_pool(
825        self: Arc<Self>,
826        url: &str,
827        connections: Option<Vec<Arc<WebsocketConnection>>>,
828    ) -> Result<(), WebsocketError> {
829        let pool: Vec<Arc<WebsocketConnection>> = match connections {
830            Some(v) => v,
831            None => self.connection_pool.clone(),
832        };
833
834        let mut tasks = FuturesUnordered::new();
835
836        for conn in pool {
837            let common = Arc::clone(&self);
838            let url = url.to_owned();
839
840            tasks.push(async move {
841                match common.init_connect(&url, false, Some(conn)).await {
842                    Ok(()) => {
843                        info!("Successfully connected to {}", url);
844                        Ok(())
845                    }
846                    Err(err) => {
847                        error!("Failed to connect to {}: {:?}", url, err);
848                        Err(err)
849                    }
850                }
851            });
852        }
853
854        while let Some(result) = tasks.next().await {
855            result?;
856        }
857
858        Ok(())
859    }
860
861    /// Initializes a WebSocket connection for a specific connection in the pool
862    ///
863    /// # Arguments
864    ///
865    /// * `url` - The WebSocket server URL to connect to
866    /// * `is_renewal` - Flag indicating whether this is a connection renewal attempt
867    /// * `connection` - Optional specific WebSocket connection to use, otherwise selects from the pool
868    ///
869    /// # Returns
870    ///
871    /// A `Result` indicating whether the connection was successfully established
872    ///
873    /// # Errors
874    ///
875    /// Returns a `WebsocketError` if the connection fails to initialize or establish
876    ///
877    /// # Behavior
878    ///
879    /// Handles connection establishment, splitting read/write streams, spawning reader/writer tasks,
880    /// and managing connection state including renewal, reconnection, and error handling
881    async fn init_connect(
882        self: Arc<Self>,
883        url: &str,
884        is_renewal: bool,
885        connection: Option<Arc<WebsocketConnection>>,
886    ) -> Result<(), WebsocketError> {
887        let conn = connection.unwrap_or(self.get_connection(true, None).await?);
888
889        {
890            let mut conn_state = conn.state.lock().await;
891            if conn_state.renewal_pending && is_renewal {
892                info!("Renewal in progress {}→{}", conn.id, url);
893                return Ok(());
894            }
895            if conn_state.ws_write_tx.is_some() && !is_renewal && !conn_state.reconnection_pending {
896                info!("Exists {}; skipping {}", conn.id, url);
897                return Ok(());
898            }
899            if is_renewal {
900                conn_state.renewal_pending = true;
901            }
902
903            conn_state.is_session_logged_on = false;
904        }
905
906        let ws = Self::create_websocket(url, self.agent.clone(), self.user_agent.clone())
907            .await
908            .map_err(|e| {
909                error!("Handshake failed {}: {:?}", url, e);
910                e
911            })?;
912
913        info!("Established {} → {}", conn.id, url);
914
915        if let Err(e) = self.renewal_tx.try_send((conn.id.clone(), url.to_string())) {
916            error!("Failed to schedule renewal for {}: {:?}", conn.id, e);
917        }
918
919        let (write_half, mut read_half) = ws.split();
920        let (tx, mut rx) = unbounded_channel::<Message>();
921
922        let old_writer = {
923            let mut conn_state = conn.state.lock().await;
924            conn_state.reconnection_pending = false;
925            conn_state.ws_write_tx.replace(tx.clone())
926        };
927
928        {
929            let wconn = conn.clone();
930            let common_clone = self.clone();
931            let writer_url = url.to_string();
932
933            spawn(async move {
934                let mut sink = write_half;
935                while let Some(msg) = rx.recv().await {
936                    if let Err(e) = sink.send(msg).await {
937                        let failure_reason =
938                            WebsocketConnectionFailureReason::from_tungstenite_error(&e);
939
940                        error!(
941                            "Write error on {}: {:?}, classified as {:?}",
942                            wconn.id, e, failure_reason
943                        );
944
945                        // Apply same reconnection logic as reader errors
946                        let mut conn_state = wconn.state.lock().await;
947                        if !conn_state.close_initiated
948                            && !is_renewal
949                            && failure_reason.should_reconnect()
950                        {
951                            info!(
952                                "Writer connection {} has recoverable error, attempting reconnection: {:?}",
953                                wconn.id, failure_reason
954                            );
955                            conn_state.reconnection_pending = true;
956                            conn_state.is_session_logged_on = false;
957                            drop(conn_state);
958                            let reconnect_url = common_clone
959                                .get_reconnect_url(&writer_url, Arc::clone(&wconn))
960                                .await;
961
962                            let _ = common_clone
963                                .reconnect_tx
964                                .send(ReconnectEntry {
965                                    connection_id: wconn.id.clone(),
966                                    url: reconnect_url,
967                                    is_renewal: false,
968                                })
969                                .await;
970                        } else {
971                            warn!(
972                                "Writer connection {} has permanent error, will not reconnect: {:?}",
973                                wconn.id, failure_reason
974                            );
975                        }
976
977                        break;
978                    }
979                }
980                debug!("Writer {} exit", wconn.id);
981            });
982        }
983
984        {
985            let common = self.clone();
986            let conn = conn.clone();
987            let url = url.to_string();
988            spawn(async move {
989                common.on_open(url, conn, old_writer).await;
990            });
991        }
992
993        {
994            let common = self.clone();
995            let reader_conn = conn.clone();
996            let read_url = url.to_string();
997
998            spawn(async move {
999                while let Some(item) = read_half.next().await {
1000                    match item {
1001                        Ok(Message::Text(msg)) => {
1002                            common
1003                                .on_message(msg.to_string(), Arc::clone(&reader_conn))
1004                                .await;
1005                        }
1006                        Ok(Message::Binary(bin)) => {
1007                            let mut decoder = ZlibDecoder::new(&bin[..]);
1008                            let mut decompressed = String::new();
1009                            if let Err(err) = decoder.read_to_string(&mut decompressed) {
1010                                error!("Binary message decompress failed: {:?}", err);
1011                                continue;
1012                            }
1013                            common
1014                                .on_message(decompressed, Arc::clone(&reader_conn))
1015                                .await;
1016                        }
1017                        Ok(Message::Ping(payload)) => {
1018                            info!("PING received from server on {}", reader_conn.id);
1019                            common.events.emit(&WebsocketEvent::Ping);
1020                            if let Some(tx) = reader_conn.state.lock().await.ws_write_tx.clone() {
1021                                let _ = tx.send(Message::Pong(payload));
1022                                info!(
1023                                    "Responded PONG to server's PING message on {}",
1024                                    reader_conn.id
1025                                );
1026                            }
1027                        }
1028                        Ok(Message::Pong(_)) => {
1029                            info!("Received PONG from server on {}", reader_conn.id);
1030                            common.events.emit(&WebsocketEvent::Pong);
1031                        }
1032                        Ok(Message::Close(frame)) => {
1033                            let (code, reason) = frame
1034                                .map_or((1000, String::new()), |CloseFrame { code, reason }| {
1035                                    (code.into(), reason.to_string())
1036                                });
1037                            common
1038                                .events
1039                                .emit(&WebsocketEvent::Close(code, reason.clone()));
1040
1041                            // Classify the close reason
1042                            let user_initiated = {
1043                                let conn_state = reader_conn.state.lock().await;
1044                                conn_state.close_initiated
1045                            };
1046
1047                            let failure_reason = WebsocketConnectionFailureReason::from_close_code(
1048                                code,
1049                                user_initiated,
1050                            );
1051
1052                            info!(
1053                                "Connection {} received close frame: code={}, reason='{}', classified as {:?}",
1054                                reader_conn.id, code, reason, failure_reason
1055                            );
1056
1057                            let mut conn_state = reader_conn.state.lock().await;
1058                            if !conn_state.close_initiated
1059                                && !is_renewal
1060                                && failure_reason.should_reconnect()
1061                            {
1062                                info!(
1063                                    "Connection {} received close frame with reconnectable failure: {:?}",
1064                                    reader_conn.id, failure_reason
1065                                );
1066                                conn_state.reconnection_pending = true;
1067                                conn_state.is_session_logged_on = false;
1068                                drop(conn_state);
1069                                let reconnect_url = common
1070                                    .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1071                                    .await;
1072
1073                                let _ = common
1074                                    .reconnect_tx
1075                                    .send(ReconnectEntry {
1076                                        connection_id: reader_conn.id.clone(),
1077                                        url: reconnect_url,
1078                                        is_renewal: false,
1079                                    })
1080                                    .await;
1081                            } else {
1082                                warn!(
1083                                    "Connection {} received close frame with non-reconnectable failure: {:?}",
1084                                    reader_conn.id, failure_reason
1085                                );
1086
1087                                // Emit detailed error for permanent failures
1088                                common.events.emit(&WebsocketEvent::Error(format!(
1089                                    "[CRITICAL] Connection {} permanently failed: {:?}",
1090                                    reader_conn.id, failure_reason
1091                                )));
1092                            }
1093
1094                            break;
1095                        }
1096                        Err(e) => {
1097                            // Classify the error type for reconnection decision
1098                            let failure_reason =
1099                                WebsocketConnectionFailureReason::from_tungstenite_error(&e);
1100
1101                            error!(
1102                                "WebSocket error on {}: {:?}, classified as {:?}",
1103                                reader_conn.id, e, failure_reason
1104                            );
1105
1106                            common.events.emit(&WebsocketEvent::Error(e.to_string()));
1107
1108                            // Apply the same reconnection logic as Close frames
1109                            let mut conn_state = reader_conn.state.lock().await;
1110                            if !conn_state.close_initiated
1111                                && !is_renewal
1112                                && failure_reason.should_reconnect()
1113                            {
1114                                info!(
1115                                    "Connection {} has recoverable error, attempting reconnection: {:?}",
1116                                    reader_conn.id, failure_reason
1117                                );
1118                                conn_state.reconnection_pending = true;
1119                                conn_state.is_session_logged_on = false;
1120                                drop(conn_state);
1121                                let reconnect_url = common
1122                                    .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1123                                    .await;
1124
1125                                let _ = common
1126                                    .reconnect_tx
1127                                    .send(ReconnectEntry {
1128                                        connection_id: reader_conn.id.clone(),
1129                                        url: reconnect_url,
1130                                        is_renewal: false,
1131                                    })
1132                                    .await;
1133                            } else {
1134                                warn!(
1135                                    "Connection {} has permanent error, will not reconnect: {:?}",
1136                                    reader_conn.id, failure_reason
1137                                );
1138
1139                                // Emit critical error for non-reconnectable failures
1140                                common.events.emit(&WebsocketEvent::Error(format!(
1141                                    "[CRITICAL] Connection {} permanently failed: {:?}",
1142                                    reader_conn.id, failure_reason
1143                                )));
1144                            }
1145
1146                            break;
1147                        }
1148                        _ => {}
1149                    }
1150                }
1151
1152                // Handle case where stream ends unexpectedly (e.g., network disconnection)
1153                info!("WebSocket stream ended for connection {}", reader_conn.id);
1154
1155                // Handle unexpected stream end with same logic as other errors
1156                let failure_reason = WebsocketConnectionFailureReason::StreamEnded;
1157
1158                info!(
1159                    "WebSocket stream ended for connection {}, classified as {:?}",
1160                    reader_conn.id, failure_reason
1161                );
1162
1163                let mut conn_state = reader_conn.state.lock().await;
1164                if !conn_state.close_initiated && !is_renewal && failure_reason.should_reconnect() {
1165                    info!(
1166                        "Connection {} stream ended unexpectedly, attempting reconnection",
1167                        reader_conn.id
1168                    );
1169                    conn_state.reconnection_pending = true;
1170                    conn_state.is_session_logged_on = false;
1171                    conn_state.ws_write_tx = None;
1172                    drop(conn_state);
1173                    let reconnect_url = common
1174                        .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
1175                        .await;
1176
1177                    let _ = common
1178                        .reconnect_tx
1179                        .send(ReconnectEntry {
1180                            connection_id: reader_conn.id.clone(),
1181                            url: reconnect_url,
1182                            is_renewal: false,
1183                        })
1184                        .await;
1185                } else {
1186                    debug!(
1187                        "Connection {} stream ended normally (close_initiated={}, is_renewal={})",
1188                        reader_conn.id, conn_state.close_initiated, is_renewal
1189                    );
1190                }
1191
1192                debug!("Reader actor for {} exiting", reader_conn.id);
1193            });
1194        }
1195
1196        Ok(())
1197    }
1198    /// Gracefully disconnects all active WebSocket connections.
1199    ///
1200    /// This method attempts to close all connections in the connection pool within a 30-second timeout.
1201    /// It marks each connection as close-initiated and attempts to close them gracefully.
1202    ///
1203    /// # Returns
1204    ///
1205    /// - `Ok(())` if all connections are successfully closed
1206    /// - `Err(WebsocketError)` if there are errors during disconnection or a timeout occurs
1207    ///
1208    /// # Errors
1209    ///
1210    /// Returns `WebsocketError::Timeout` if disconnection takes longer than 30 seconds
1211    ///
1212    async fn disconnect(&self) -> Result<(), WebsocketError> {
1213        if !self.is_connected(None).await {
1214            warn!("No active connection to close.");
1215            return Ok(());
1216        }
1217
1218        let mut shutdowns = FuturesUnordered::new();
1219        for conn in &self.connection_pool {
1220            {
1221                let mut conn_state = conn.state.lock().await;
1222                conn_state.close_initiated = true;
1223                if let Some(tx) = &conn_state.ws_write_tx {
1224                    shutdowns.push(self.close_connection_gracefully(tx.clone(), Arc::clone(conn)));
1225                }
1226            }
1227        }
1228
1229        let close_all = async {
1230            while let Some(result) = shutdowns.next().await {
1231                result?;
1232            }
1233            Ok::<(), WebsocketError>(())
1234        };
1235
1236        match timeout(Duration::from_secs(30), close_all).await {
1237            Ok(Ok(())) => {
1238                info!("Disconnected all WebSocket connections successfully.");
1239                for conn in &self.connection_pool {
1240                    let mut st = conn.state.lock().await;
1241                    st.is_session_logged_on = false;
1242                    st.session_logon_req = None;
1243                }
1244                Ok(())
1245            }
1246            Ok(Err(err)) => {
1247                error!("Error while disconnecting: {:?}", err);
1248                Err(err)
1249            }
1250            Err(_) => {
1251                error!("Timed out while disconnecting WebSocket connections.");
1252                Err(WebsocketError::Timeout)
1253            }
1254        }
1255    }
1256
1257    /// Sends a PING message to all ready WebSocket connections.
1258    ///
1259    /// This method iterates through the connection pool, identifies ready connections,
1260    /// and sends a PING message to each of them. It logs the number of connections
1261    /// being pinged and handles any send errors individually.
1262    ///
1263    /// # Behavior
1264    ///
1265    /// - Skips connections that are not ready
1266    /// - Logs a warning if no connections are ready
1267    /// - Sends PING messages concurrently
1268    /// - Logs debug/error messages for each PING attempt
1269    async fn ping_server(&self) {
1270        let mut ready = Vec::new();
1271        for conn in &self.connection_pool {
1272            if self.is_connection_ready(conn, false).await {
1273                let id = conn.id.clone();
1274                let ws_write_tx = {
1275                    let conn_state = conn.state.lock().await;
1276                    conn_state.ws_write_tx.clone()
1277                };
1278                ready.push((id, ws_write_tx));
1279            }
1280        }
1281
1282        if ready.is_empty() {
1283            warn!("No ready connections for PING.");
1284            return;
1285        }
1286        info!("Sending PING to {} WebSocket connections.", ready.len());
1287
1288        let mut tasks = FuturesUnordered::new();
1289        for (id, ws_write_tx_opt) in ready {
1290            if let Some(tx) = ws_write_tx_opt {
1291                tasks.push(async move {
1292                    if let Err(e) = tx.send(Message::Ping(Vec::new().into())) {
1293                        error!("Failed to send PING to {}: {:?}", id, e);
1294                    } else {
1295                        debug!("Sent PING to connection {}", id);
1296                    }
1297                });
1298            } else {
1299                error!("Connection {} was ready but has no write channel", id);
1300            }
1301        }
1302
1303        while tasks.next().await.is_some() {}
1304    }
1305
1306    /// Sends a WebSocket message and optionally waits for a reply.
1307    ///
1308    /// # Arguments
1309    ///
1310    /// * `payload` - The message payload to send
1311    /// * `id` - Optional request identifier, required when waiting for a reply
1312    /// * `wait_for_reply` - Whether to wait for a response to the message
1313    /// * `timeout` - Maximum duration to wait for a reply
1314    /// * `connection` - Optional specific WebSocket connection to use
1315    ///
1316    /// # Returns
1317    ///
1318    /// A receiver for the response if `wait_for_reply` is true, otherwise `None`
1319    ///
1320    /// # Errors
1321    ///
1322    /// Returns a `WebsocketError` if the connection is not ready or the send fails
1323    async fn send(
1324        &self,
1325        payload: String,
1326        id: Option<String>,
1327        wait_for_reply: bool,
1328        timeout: Duration,
1329        connection: Option<Arc<WebsocketConnection>>,
1330    ) -> Result<Option<oneshot::Receiver<Result<Value, WebsocketError>>>, WebsocketError> {
1331        let conn = if let Some(c) = connection {
1332            c
1333        } else {
1334            self.get_connection(false, None).await?
1335        };
1336
1337        if !self.is_connected(Some(&conn)).await {
1338            warn!("Send attempted on a non-connected socket");
1339            return Err(WebsocketError::NotConnected);
1340        }
1341
1342        let ws_write_tx = {
1343            let conn_state = conn.state.lock().await;
1344            conn_state
1345                .ws_write_tx
1346                .clone()
1347                .ok_or(WebsocketError::NotConnected)?
1348        };
1349
1350        debug!("Sending message to WebSocket on connection {}", conn.id);
1351
1352        ws_write_tx
1353            .send(Message::Text(payload.clone().into()))
1354            .map_err(|_| WebsocketError::NotConnected)?;
1355
1356        if !wait_for_reply {
1357            return Ok(None);
1358        }
1359
1360        let request_id = id.ok_or_else(|| {
1361            error!("id is required when waiting for a reply");
1362            WebsocketError::NotConnected
1363        })?;
1364
1365        let (tx, rx) = oneshot::channel();
1366        {
1367            let mut conn_state = conn.state.lock().await;
1368            conn_state
1369                .pending_requests
1370                .insert(request_id.clone(), PendingRequest { completion: tx });
1371        }
1372
1373        let conn_clone = Arc::clone(&conn);
1374        spawn(async move {
1375            sleep(timeout).await;
1376            let mut conn_state = conn_clone.state.lock().await;
1377            if let Some(pending_req) = conn_state.pending_requests.remove(&request_id) {
1378                let _ = pending_req.completion.send(Err(WebsocketError::Timeout));
1379            }
1380        });
1381
1382        Ok(Some(rx))
1383    }
1384}
1385
/// Per-request flags passed to the WebSocket API message builder.
///
/// Start from [`WebsocketMessageSendOptions::new`] and chain the builder methods; each one
/// consumes the options and returns them with one flag set.
#[derive(Debug, Default, Clone)]
pub struct WebsocketMessageSendOptions {
    /// Request the API key to be included. Presumably this is honoured by
    /// `build_websocket_api_message` — confirm there.
    pub with_api_key: bool,
    /// Request the message to be signed. Presumably this is honoured by
    /// `build_websocket_api_message` — confirm there.
    pub is_signed: bool,
    /// `Some(true)` marks the request as a session logon.
    pub is_session_logon: Option<bool>,
    /// `Some(true)` marks the request as a session logout.
    pub is_session_logout: Option<bool>,
}

impl WebsocketMessageSendOptions {
    /// Returns options with every flag cleared (identical to `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self {
            with_api_key: false,
            is_signed: false,
            is_session_logon: None,
            is_session_logout: None,
        }
    }

    /// Sets `with_api_key`.
    #[must_use]
    pub fn with_api_key(self) -> Self {
        Self {
            with_api_key: true,
            ..self
        }
    }

    /// Sets `is_signed`.
    #[must_use]
    pub fn signed(self) -> Self {
        Self {
            is_signed: true,
            ..self
        }
    }

    /// Sets `is_session_logon` to `Some(true)`.
    #[must_use]
    pub fn session_logon(self) -> Self {
        Self {
            is_session_logon: Some(true),
            ..self
        }
    }

    /// Sets `is_session_logout` to `Some(true)`.
    #[must_use]
    pub fn session_logout(self) -> Self {
        Self {
            is_session_logout: Some(true),
            ..self
        }
    }
}
1424
1425#[derive(Debug)]
1426pub enum SendWebsocketMessageResult<R> {
1427    Single(WebsocketApiResponse<R>),
1428    Multiple(Vec<WebsocketApiResponse<R>>),
1429}
1430
1431impl<R> IntoIterator for SendWebsocketMessageResult<R> {
1432    type Item = WebsocketApiResponse<R>;
1433    type IntoIter = std::vec::IntoIter<Self::Item>;
1434
1435    fn into_iter(self) -> Self::IntoIter {
1436        match self {
1437            SendWebsocketMessageResult::Single(resp) => vec![resp].into_iter(),
1438            SendWebsocketMessageResult::Multiple(v) => v.into_iter(),
1439        }
1440    }
1441}
1442
/// Client for the request/response WebSocket API, built on a shared connection pool.
pub struct WebsocketApi {
    /// Shared pool machinery used for connect, send, ping and disconnect.
    pub common: Arc<WebsocketCommon>,
    /// Settings this client was created with (URL, request timeout, session re-logon, ...).
    configuration: ConfigurationWebsocketApi,
    /// True while `connect` is running; concurrent `connect` calls return early while set.
    is_connecting: Arc<Mutex<bool>>,
    /// Callbacks receiving raw JSON values, keyed by a string. Presumably the key is a
    /// stream/event identifier — confirm against the code that registers them.
    stream_callbacks: Mutex<HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>>,
}
1449
1450impl WebsocketApi {
1451    #[must_use]
1452    /// Creates a new WebSocket API instance with the given configuration and connection pool.
1453    ///
1454    /// # Arguments
1455    ///
1456    /// * `configuration` - Configuration settings for the WebSocket API
1457    /// * `connection_pool` - A vector of WebSocket connections to be used
1458    ///
1459    /// # Returns
1460    ///
1461    /// An `Arc`-wrapped `WebsocketApi` instance ready for use
1462    ///
1463    /// # Panics
1464    ///
1465    /// This function will panic if the configuration is not valid.
1466    ///
1467    /// # Examples
1468    ///
1469    ///
1470    /// let api = `WebsocketApi::new(config`, `connection_pool`);
1471    ///
1472    pub fn new(
1473        configuration: ConfigurationWebsocketApi,
1474        connection_pool: Vec<Arc<WebsocketConnection>>,
1475    ) -> Arc<Self> {
1476        let agent_clone = configuration.agent.clone();
1477        let user_agent_clone = configuration.user_agent.clone();
1478        let common = WebsocketCommon::new(
1479            connection_pool,
1480            configuration.mode.clone(),
1481            usize::try_from(configuration.reconnect_delay)
1482                .expect("reconnect_delay should fit in usize"),
1483            agent_clone,
1484            Some(user_agent_clone),
1485        );
1486
1487        Arc::new(Self {
1488            common: Arc::clone(&common),
1489            configuration,
1490            is_connecting: Arc::new(Mutex::new(false)),
1491            stream_callbacks: Mutex::new(HashMap::new()),
1492        })
1493    }
1494
1495    /// Connects to a WebSocket server with a configurable timeout and connection handling.
1496    ///
1497    /// This method attempts to establish a WebSocket connection if not already connected.
1498    /// It prevents multiple simultaneous connection attempts and supports a connection pool.
1499    ///
1500    /// # Errors
1501    ///
1502    /// Returns a `WebsocketError` if:
1503    /// - Connection fails
1504    /// - Connection times out after 10 seconds
1505    ///
1506    /// # Behavior
1507    ///
1508    /// - Checks if already connected and returns early if so
1509    /// - Prevents multiple concurrent connection attempts
1510    /// - Sets a WebSocket handler for the connection pool
1511    /// - Attempts to connect with a 10-second timeout
1512    ///
1513    /// # Returns
1514    ///
1515    /// `Ok(())` if connection is successful, otherwise a `WebsocketError`
1516    pub async fn connect(self: Arc<Self>) -> Result<(), WebsocketError> {
1517        if self.common.is_connected(None).await {
1518            info!("WebSocket connection already established");
1519            return Ok(());
1520        }
1521
1522        {
1523            let mut flag = self.is_connecting.lock().await;
1524            if *flag {
1525                info!("Already connecting...");
1526                return Ok(());
1527            }
1528            *flag = true;
1529        }
1530
1531        let url = self.prepare_url(self.configuration.ws_url.as_deref().unwrap_or_default());
1532
1533        let handler: Arc<dyn WebsocketHandler> = self.clone();
1534        for slot in &self.common.connection_pool {
1535            slot.set_handler(handler.clone()).await;
1536        }
1537
1538        let result = select! {
1539            () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1540            r = self.common.clone().connect_pool(&url, None) => r,
1541        };
1542
1543        {
1544            let mut flag = self.is_connecting.lock().await;
1545            *flag = false;
1546        }
1547
1548        result
1549    }
1550
1551    /// Disconnects the WebSocket connection.
1552    ///
1553    /// # Returns
1554    ///
1555    /// `Ok(())` if disconnection is successful, otherwise a `WebsocketError`
1556    ///
1557    /// # Errors
1558    ///
1559    /// Returns a `WebsocketError` if:
1560    /// - Disconnection fails
1561    /// - Connection is not established
1562    ///
1563    pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1564        self.common.disconnect().await
1565    }
1566
1567    /// Checks if the WebSocket connection is currently established.
1568    ///
1569    /// # Returns
1570    ///
1571    /// `true` if the connection is active, `false` otherwise.
1572    pub async fn is_connected(&self) -> bool {
1573        self.common.is_connected(None).await
1574    }
1575
1576    /// Sends a ping to the WebSocket server to maintain the connection.
1577    ///
1578    /// This method calls the underlying connection's ping mechanism to check
1579    /// and keep the WebSocket connection alive.
1580    pub async fn ping_server(&self) {
1581        self.common.ping_server().await;
1582    }
1583
1584    /// Sends a WebSocket message with the specified method and payload.
1585    ///
1586    /// This method prepares and sends a WebSocket request with optional API key and signature.
1587    /// It handles connection status, generates a unique request ID, and processes the response.
1588    ///
1589    /// # Arguments
1590    ///
1591    /// * `method` - The WebSocket API method to be called
1592    /// * `payload` - A map of parameters to be sent with the request
1593    /// * `options` - Configuration options for message sending (API key, signing)
1594    ///
1595    /// # Returns
1596    ///
1597    /// A deserialized response of type `R` or a `WebsocketError` if the request fails
1598    ///
1599    /// # Panics
1600    ///
1601    /// Panics if:
1602    ///
1603    /// - The WebSocket is not connected
1604    /// - The request cannot be processed
1605    /// - No response is received within the timeout
1606    ///
1607    /// # Errors
1608    ///
1609    /// Returns `WebsocketError` if:
1610    /// - The WebSocket is not connected
1611    /// - The request cannot be processed
1612    /// - No response is received within the timeout
1613    pub async fn send_message<R>(
1614        &self,
1615        method: &str,
1616        payload: BTreeMap<String, Value>,
1617        options: WebsocketMessageSendOptions,
1618    ) -> Result<SendWebsocketMessageResult<R>, WebsocketError>
1619    where
1620        R: DeserializeOwned + Send + Sync + 'static,
1621    {
1622        if !self.common.is_connected(None).await {
1623            return Err(WebsocketError::NotConnected);
1624        }
1625
1626        let do_multi =
1627            options.is_session_logon.unwrap_or(false) || options.is_session_logout.unwrap_or(false);
1628
1629        let connections = if do_multi {
1630            self.common.get_available_connections(false, None).await
1631        } else {
1632            vec![self.common.get_connection(false, None).await?]
1633        };
1634
1635        let skip_auth = if do_multi {
1636            false
1637        } else {
1638            let connection = &connections[0];
1639            let conn_state = connection.state.lock().await;
1640            self.configuration.auto_session_relogon && conn_state.is_session_logged_on
1641        };
1642
1643        let payload_clone = payload.clone();
1644
1645        let (id, request) =
1646            build_websocket_api_message(&self.configuration, method, payload, &options, skip_auth);
1647        let raw_payload = serde_json::to_string(&request).unwrap();
1648        debug!("Sending message to WebSocket API: {:?}", request);
1649
1650        let timeout = Duration::from_millis(self.configuration.timeout);
1651
1652        let mut receivers = Vec::with_capacity(connections.len());
1653        for connection in &connections {
1654            let opt_rx = self
1655                .common
1656                .send(
1657                    raw_payload.clone(),
1658                    Some(id.clone()),
1659                    true,
1660                    timeout,
1661                    Some(connection.clone()),
1662                )
1663                .await?;
1664            receivers.push((connection.clone(), opt_rx));
1665        }
1666
1667        let mut raw_msgs = Vec::with_capacity(receivers.len());
1668        for (_conn, opt_rx) in receivers {
1669            let rx = opt_rx.ok_or(WebsocketError::NoResponse)?;
1670            let msg = rx.await.unwrap_or(Err(WebsocketError::Timeout))?;
1671            raw_msgs.push(msg);
1672        }
1673
1674        let mut responses = Vec::with_capacity(raw_msgs.len());
1675        for msg in raw_msgs {
1676            let raw = msg
1677                .get("result")
1678                .or_else(|| msg.get("response"))
1679                .cloned()
1680                .unwrap_or(Value::Null);
1681
1682            let rate_limits = msg
1683                .get("rateLimits")
1684                .and_then(Value::as_array)
1685                .map(|arr| {
1686                    arr.iter()
1687                        .filter_map(|v| serde_json::from_value(v.clone()).ok())
1688                        .collect()
1689                })
1690                .unwrap_or_default();
1691
1692            responses.push(WebsocketApiResponse {
1693                raw,
1694                rate_limits,
1695                _marker: PhantomData,
1696            });
1697        }
1698
1699        if do_multi && self.configuration.auto_session_relogon {
1700            for connection in &connections {
1701                let mut state = connection.state.lock().await;
1702                if options.is_session_logon.unwrap_or(false) {
1703                    state.is_session_logged_on = true;
1704                    state.session_logon_req = Some(WebsocketSessionLogonReq {
1705                        method: method.to_string(),
1706                        payload: payload_clone.clone(),
1707                        options: options.clone(),
1708                    });
1709                } else {
1710                    state.is_session_logged_on = false;
1711                    state.session_logon_req = None;
1712                }
1713            }
1714        }
1715
1716        Ok(if responses.len() == 1 && !do_multi {
1717            SendWebsocketMessageResult::Single(responses.into_iter().next().unwrap())
1718        } else {
1719            SendWebsocketMessageResult::Multiple(responses)
1720        })
1721    }
1722
1723    /// Prepares a WebSocket URL by appending a validated time unit parameter.
1724    ///
1725    /// This method checks if a time unit is configured and validates it. If valid,
1726    /// the time unit is appended to the URL as a query parameter. If no time unit
1727    /// is specified or the validation fails, the original URL is returned.
1728    ///
1729    /// # Arguments
1730    ///
1731    /// * `ws_url` - The base WebSocket URL to be modified
1732    ///
1733    /// # Returns
1734    ///
1735    /// A modified URL with the time unit parameter, or the original URL if no
1736    /// modification is possible
1737    fn prepare_url(&self, ws_url: &str) -> String {
1738        let mut url = ws_url.to_string();
1739
1740        let time_unit = match &self.configuration.time_unit {
1741            Some(u) => u.to_string(),
1742            None => return url,
1743        };
1744
1745        match validate_time_unit(&time_unit) {
1746            Ok(Some(validated)) => {
1747                let sep = if url.contains('?') { '&' } else { '?' };
1748                url.push(sep);
1749                url.push_str("timeUnit=");
1750                url.push_str(validated);
1751            }
1752            Ok(None) => {}
1753            Err(e) => {
1754                error!("Invalid time unit provided: {:?}", e);
1755            }
1756        }
1757
1758        url
1759    }
1760}
1761
1762#[async_trait]
1763impl WebsocketHandler for WebsocketApi {
1764    /// Handles the WebSocket connection opening event, attempting to re-establish a session logon if needed.
1765    ///
1766    /// This method checks if a session logon request exists and has not already been logged on.
1767    /// If conditions are met, it attempts to send a session re-logon message and update the connection state.
1768    ///
1769    /// # Arguments
1770    ///
1771    /// * `_url` - The WebSocket connection URL (unused)
1772    /// * `connection` - The WebSocket connection context
1773    ///
1774    /// # Behavior
1775    ///
1776    /// - Checks for an existing session logon request
1777    /// - Verifies the session is not already logged on
1778    /// - Attempts to send a re-logon message
1779    /// - Updates connection state upon successful re-logon
1780    /// - Logs errors if re-logon dispatch fails
1781    async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
1782        let session_req = {
1783            let conn_state = connection.state.lock().await;
1784            conn_state.session_logon_req.clone()
1785        };
1786
1787        let Some(req) = session_req else {
1788            return;
1789        };
1790
1791        let already_logged_on = {
1792            let conn_state = connection.state.lock().await;
1793            conn_state.is_session_logged_on
1794        };
1795
1796        if already_logged_on {
1797            debug!(
1798                "Connection {} already logged on, skipping re-logon",
1799                connection.id
1800            );
1801            return;
1802        }
1803
1804        let conn = connection.clone();
1805        let common = Arc::clone(&self.common);
1806        let configuration = self.configuration.clone();
1807        let method = req.method.clone();
1808        let payload = req.payload.clone();
1809        let options = req.options.clone();
1810
1811        spawn(async move {
1812            let (id, json_msg) =
1813                build_websocket_api_message(&configuration, &method, payload, &options, false);
1814
1815            let raw_message = match serde_json::to_string(&json_msg) {
1816                Ok(msg) => msg,
1817                Err(e) => {
1818                    warn!(
1819                        "Failed to serialize session logon message for connection {}: {}",
1820                        conn.id, e
1821                    );
1822                    return;
1823                }
1824            };
1825
1826            debug!(
1827                "Session re-logon on connection {}: {}",
1828                conn.id, raw_message
1829            );
1830
1831            let rx = match common
1832                .send(
1833                    raw_message,
1834                    Some(id.clone()),
1835                    true,
1836                    Duration::from_millis(configuration.timeout),
1837                    Some(conn.clone()),
1838                )
1839                .await
1840            {
1841                Ok(Some(rx)) => rx,
1842                Ok(None) => {
1843                    warn!(
1844                        "Session re-logon dispatch returned None for connection {}",
1845                        conn.id
1846                    );
1847                    return;
1848                }
1849                Err(e) => {
1850                    warn!(
1851                        "Session re-logon dispatch failed on connection {}: {}",
1852                        conn.id, e
1853                    );
1854                    return;
1855                }
1856            };
1857
1858            let Ok(result) = timeout(Duration::from_millis(configuration.timeout), rx).await else {
1859                warn!("Session re-logon timed out on connection {}", conn.id);
1860                return;
1861            };
1862
1863            let final_result = match result {
1864                Ok(final_result) => final_result,
1865                Err(e) => {
1866                    warn!(
1867                        "Session re-logon receiver error on connection {}: {}",
1868                        conn.id, e
1869                    );
1870                    return;
1871                }
1872            };
1873
1874            let payload = match final_result {
1875                Ok(payload) => payload,
1876                Err(e) => {
1877                    warn!(
1878                        "Session re-logon payload error on connection {}: {}",
1879                        conn.id, e
1880                    );
1881                    return;
1882                }
1883            };
1884
1885            debug!(
1886                "Session re-logon succeeded on connection {}: {}",
1887                conn.id, payload
1888            );
1889            let mut conn_state = conn.state.lock().await;
1890            conn_state.is_session_logged_on = true;
1891        });
1892    }
1893
1894    /// Handles incoming WebSocket messages by parsing the JSON payload and processing pending requests.
1895    ///
1896    /// This method is responsible for:
1897    /// - Parsing the received WebSocket message as JSON
1898    /// - Matching the message to a pending request by its ID
1899    /// - Sending the response back to the original request's completion channel
1900    /// - Handling both successful and error responses
1901    ///
1902    /// # Arguments
1903    ///
1904    /// * `data` - The raw WebSocket message as a string
1905    /// * `connection` - The WebSocket connection context associated with the message
1906    ///
1907    /// # Behavior
1908    ///
1909    /// - If message parsing fails, logs an error and returns
1910    /// - For known request IDs, sends the response to the corresponding completion channel
1911    /// - Warns about responses for unknown or timed-out requests
1912    /// - Differentiates between successful (status < 400) and error responses
1913    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
1914        let msg: Value = match serde_json::from_str(&data) {
1915            Ok(v) => v,
1916            Err(err) => {
1917                error!("Failed to parse WebSocket message {} – {}", data, err);
1918                return;
1919            }
1920        };
1921
1922        if let Some(id) = msg.get("id").and_then(Value::as_str) {
1923            let maybe_sender = {
1924                let mut conn_state = connection.state.lock().await;
1925                conn_state.pending_requests.remove(id)
1926            };
1927
1928            if let Some(PendingRequest { completion }) = maybe_sender {
1929                connection.drain_notify.notify_one();
1930                let status = msg.get("status").and_then(Value::as_u64).unwrap_or(200);
1931                if status >= 400 {
1932                    let error_map = msg
1933                        .get("error")
1934                        .and_then(Value::as_object)
1935                        .unwrap_or(&serde_json::Map::new())
1936                        .clone();
1937
1938                    let code = error_map
1939                        .get("code")
1940                        .and_then(Value::as_i64)
1941                        .unwrap_or(status.try_into().unwrap());
1942
1943                    let message = error_map
1944                        .get("msg")
1945                        .and_then(Value::as_str)
1946                        .unwrap_or("Unknown error")
1947                        .to_string();
1948
1949                    let _ = completion.send(Err(WebsocketError::ResponseError { code, message }));
1950                } else {
1951                    let _ = completion.send(Ok(msg.clone()));
1952                }
1953            }
1954
1955            return;
1956        }
1957
1958        if let Some(event) = msg.get("event") {
1959            if event.get("e").is_some() {
1960                for callbacks in self.stream_callbacks.lock().await.values() {
1961                    for callback in callbacks {
1962                        callback(event);
1963                    }
1964                }
1965
1966                return;
1967            }
1968        }
1969
1970        warn!(
1971            "Received response for unknown or timed-out request: {}",
1972            data
1973        );
1974    }
1975
1976    /// Generates the URL to use for reconnecting to a WebSocket connection.
1977    ///
1978    /// # Arguments
1979    ///
1980    /// * `default_url` - The original URL to potentially modify for reconnection
1981    /// * `_connection` - The WebSocket connection context (currently unused)
1982    ///
1983    /// # Returns
1984    ///
1985    /// A `String` representing the URL to use for reconnecting
1986    async fn get_reconnect_url(
1987        &self,
1988        default_url: String,
1989        _connection: Arc<WebsocketConnection>,
1990    ) -> String {
1991        default_url
1992    }
1993}
1994
/// Client that subscribes to WebSocket streams over a pool of connections.
///
/// Stream subscriptions are tracked per stream key (the stream name, optionally prefixed
/// with a URL path) and mapped to the pooled connection that carries them.
pub struct WebsocketStreams {
    /// Shared connection-pool layer used to connect, send, ping and disconnect.
    pub common: Arc<WebsocketCommon>,
    /// Flag passed to `normalize_stream_id` when building request ids
    /// (NOTE(review): presumably forces numeric ids — confirm in `normalize_stream_id`).
    pub stream_id_is_strictly_number: AtomicBool,
    /// Optional URL path segments; when non-empty, each path is served by its own
    /// consecutive slice of `mode.pool_size()` connections from the pool.
    url_paths: Vec<String>,
    /// Set while `connect` is running so concurrent `connect` calls return early.
    is_connecting: Mutex<bool>,
    /// Maps a stream key to the connection the stream is assigned to.
    connection_streams: Mutex<HashMap<String, Arc<WebsocketConnection>>>,
    /// Stream client configuration (URL, mode, time unit, reconnect delay, agent, …).
    configuration: ConfigurationWebsocketStreams,
}
2003
2004impl WebsocketStreams {
2005    /// Creates a new `WebsocketStreams` instance with the given configuration and connection pool.
2006    ///
2007    /// # Arguments
2008    ///
2009    /// * `configuration` - Configuration settings for the WebSocket streams
2010    /// * `connection_pool` - A vector of WebSocket connections to use
2011    /// * `url_paths` - A vector of URL paths for the streams
2012    ///
2013    /// # Returns
2014    ///
2015    /// An `Arc`-wrapped `WebsocketStreams` instance
2016    ///
2017    /// # Panics
2018    ///
2019    /// Panics if the `reconnect_delay` cannot be converted to `usize`
2020    #[must_use]
2021    pub fn new(
2022        configuration: ConfigurationWebsocketStreams,
2023        mut connection_pool: Vec<Arc<WebsocketConnection>>,
2024        url_paths: Vec<String>,
2025    ) -> Arc<Self> {
2026        if !url_paths.is_empty() {
2027            let base_pool_size = configuration.mode.pool_size();
2028            let expected = base_pool_size * url_paths.len();
2029
2030            while connection_pool.len() < expected {
2031                connection_pool.push(WebsocketConnection::new(random_string()));
2032            }
2033        }
2034
2035        let agent_clone = configuration.agent.clone();
2036        let user_agent_clone = configuration.user_agent.clone();
2037        let common = WebsocketCommon::new(
2038            connection_pool,
2039            configuration.mode.clone(),
2040            usize::try_from(configuration.reconnect_delay)
2041                .expect("reconnect_delay should fit in usize"),
2042            agent_clone,
2043            Some(user_agent_clone),
2044        );
2045        Arc::new(Self {
2046            common,
2047            is_connecting: Mutex::new(false),
2048            connection_streams: Mutex::new(HashMap::new()),
2049            configuration,
2050            stream_id_is_strictly_number: AtomicBool::new(false),
2051            url_paths,
2052        })
2053    }
2054
2055    /// Establishes a WebSocket connection for the given streams.
2056    ///
2057    /// This method attempts to connect to a WebSocket server using the connection pool.
2058    /// If a connection is already established or in progress, it returns immediately.
2059    ///
2060    /// # Arguments
2061    ///
2062    /// * `streams` - A vector of stream identifiers to connect to
2063    ///
2064    /// # Returns
2065    ///
2066    /// A `Result` indicating whether the connection was successful or an error occurred
2067    ///
2068    /// # Errors
2069    ///
2070    /// Returns a `WebsocketError` if the connection fails or times out after 10 seconds
2071    pub async fn connect(self: Arc<Self>, streams: Vec<String>) -> Result<(), WebsocketError> {
2072        if self.common.is_connected(None).await {
2073            info!("WebSocket connection already established");
2074            return Ok(());
2075        }
2076
2077        {
2078            let mut flag = self.is_connecting.lock().await;
2079            if *flag {
2080                info!("Already connecting...");
2081                return Ok(());
2082            }
2083            *flag = true;
2084        }
2085
2086        let handler: Arc<dyn WebsocketHandler> = self.clone();
2087        for conn in &self.common.connection_pool {
2088            conn.set_handler(handler.clone()).await;
2089        }
2090
2091        let base_pool_size = self.configuration.mode.pool_size();
2092
2093        let connect_fut = async {
2094            if self.url_paths.is_empty() {
2095                let url = self.prepare_url(&streams, None);
2096                self.common.clone().connect_pool(&url, None).await
2097            } else {
2098                let mut futures = Vec::with_capacity(self.url_paths.len());
2099
2100                for (i, path) in self.url_paths.iter().enumerate() {
2101                    let start = i * base_pool_size;
2102
2103                    let subset: Vec<Arc<WebsocketConnection>> = self
2104                        .common
2105                        .connection_pool
2106                        .iter()
2107                        .skip(start)
2108                        .take(base_pool_size)
2109                        .cloned()
2110                        .collect();
2111
2112                    if subset.len() != base_pool_size {
2113                        return Err(WebsocketError::ServerError(format!(
2114                            "connection_pool too small for url_paths: need {} per path, got {} for path index {}",
2115                            base_pool_size,
2116                            subset.len(),
2117                            i
2118                        )));
2119                    }
2120
2121                    for c in &subset {
2122                        let mut st = c.state.lock().await;
2123                        st.url_path = Some(path.clone());
2124                    }
2125
2126                    let url = self.prepare_url(&streams, Some(path.as_str()));
2127                    let common = self.common.clone();
2128
2129                    futures.push(async move { common.connect_pool(&url, Some(subset)).await });
2130                }
2131
2132                try_join_all(futures).await?;
2133                Ok(())
2134            }
2135        };
2136
2137        let connect_res = select! {
2138            () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
2139            r = connect_fut => r,
2140        };
2141
2142        {
2143            let mut flag = self.is_connecting.lock().await;
2144            *flag = false;
2145        }
2146
2147        connect_res
2148    }
2149
2150    /// Disconnects all WebSocket connections and clears associated state.
2151    ///
2152    /// # Returns
2153    ///
2154    /// A `Result` indicating whether the disconnection was successful or an error occurred
2155    ///
2156    /// # Errors
2157    ///
2158    /// Returns a `WebsocketError` if there are issues during the disconnection process
2159    ///
2160    /// # Side Effects
2161    ///
2162    /// - Clears stream callbacks for all connections
2163    /// - Clears pending subscriptions for all connections
2164    /// - Removes all connection stream mappings
2165    pub async fn disconnect(&self) -> Result<(), WebsocketError> {
2166        for connection in &self.common.connection_pool {
2167            let mut conn_state = connection.state.lock().await;
2168            conn_state.stream_callbacks.clear();
2169            conn_state.pending_subscriptions.clear();
2170        }
2171        self.connection_streams.lock().await.clear();
2172        self.common.disconnect().await
2173    }
2174
2175    /// Checks if the WebSocket connection is currently active.
2176    ///
2177    /// # Returns
2178    ///
2179    /// `true` if the WebSocket connection is established, `false` otherwise.
2180    pub async fn is_connected(&self) -> bool {
2181        self.common.is_connected(None).await
2182    }
2183
2184    /// Sends a ping to the WebSocket server to maintain the connection.
2185    ///
2186    /// This method delegates the ping operation to the underlying common WebSocket connection.
2187    /// It is typically used to keep the connection alive and check its status.
2188    ///
2189    /// # Side Effects
2190    ///
2191    /// Sends a ping request to the WebSocket server through the common connection.
2192    pub async fn ping_server(&self) {
2193        self.common.ping_server().await;
2194    }
2195
2196    /// Subscribes to multiple WebSocket streams, handling connection and queuing logic.
2197    ///
2198    /// # Arguments
2199    ///
2200    /// * `streams` - A vector of stream names to subscribe to
2201    /// * `id` - An optional request identifier for the subscription
2202    /// * `url_path` - An optional URL path for the subscription
2203    ///
2204    /// # Behavior
2205    ///
2206    /// - Filters out streams already subscribed
2207    /// - Assigns streams to appropriate connections
2208    /// - Handles subscription for active connections
2209    /// - Queues subscriptions for inactive connections
2210    ///
2211    /// # Side Effects
2212    ///
2213    /// - Sends subscription payloads for active connections
2214    /// - Adds pending subscriptions for inactive connections
2215    pub async fn subscribe(
2216        self: Arc<Self>,
2217        streams: Vec<String>,
2218        id: Option<StreamId>,
2219        url_path: Option<&str>,
2220    ) {
2221        let streams: Vec<String> = {
2222            let map = self.connection_streams.lock().await;
2223            streams
2224                .into_iter()
2225                .filter(|s| {
2226                    let key = self.stream_key(s, url_path);
2227                    !map.contains_key(&key)
2228                })
2229                .collect()
2230        };
2231
2232        if streams.is_empty() {
2233            return;
2234        }
2235
2236        let connection_streams = self.handle_stream_assignment(streams, url_path).await;
2237
2238        for (conn, assigned_streams) in connection_streams {
2239            if !self.common.is_connected(Some(&conn)).await {
2240                info!(
2241                    "Connection {} is not ready. Queuing subscription for streams: {:?}",
2242                    conn.id, assigned_streams
2243                );
2244
2245                let mut conn_state = conn.state.lock().await;
2246                conn_state
2247                    .pending_subscriptions
2248                    .extend(assigned_streams.iter().cloned());
2249
2250                continue;
2251            }
2252
2253            self.send_subscription_payload(&conn, &assigned_streams, id.clone());
2254        }
2255    }
2256
2257    /// Unsubscribes from specified WebSocket streams.
2258    ///
2259    /// # Arguments
2260    ///
2261    /// * `streams` - A vector of stream names to unsubscribe from
2262    /// * `id` - An optional request identifier for the unsubscription
2263    /// * `url_path` - An optional URL path for the unsubscription
2264    ///
2265    /// # Behavior
2266    ///
2267    /// - Validates the request identifier or generates a random one
2268    /// - Checks for active connections and subscribed streams
2269    /// - Sends unsubscribe payload for streams with active callbacks
2270    /// - Removes stream from connection streams and callbacks
2271    ///
2272    /// # Side Effects
2273    ///
2274    /// - Sends unsubscribe request to WebSocket server
2275    /// - Removes stream tracking from internal state
2276    ///
2277    /// # Async
2278    ///
2279    /// This method is asynchronous and requires `.await` when called
2280    ///
2281    /// # Panics
2282    ///
2283    /// This method may panic if the request identifier is not valid.
2284    ///
2285    pub async fn unsubscribe(
2286        &self,
2287        streams: Vec<String>,
2288        id: Option<StreamId>,
2289        url_path: Option<&str>,
2290    ) {
2291        let request_id = normalize_stream_id(
2292            id.clone(),
2293            self.stream_id_is_strictly_number.load(Ordering::Relaxed),
2294        );
2295
2296        for stream in streams {
2297            let key = self.stream_key(&stream, url_path);
2298            let maybe_conn = { self.connection_streams.lock().await.get(&key).cloned() };
2299
2300            let conn = match maybe_conn {
2301                Some(c) => {
2302                    if !self.common.is_connected(Some(&c)).await {
2303                        warn!(
2304                            "Stream {} not associated with an active connection.",
2305                            stream
2306                        );
2307                        continue;
2308                    }
2309                    c
2310                }
2311                None => {
2312                    warn!("Stream {} was not subscribed.", stream);
2313                    continue;
2314                }
2315            };
2316
2317            let has_callbacks = {
2318                let conn_state = conn.state.lock().await;
2319                conn_state
2320                    .stream_callbacks
2321                    .get(&key)
2322                    .is_some_and(|v| !v.is_empty())
2323            };
2324
2325            if has_callbacks {
2326                continue;
2327            }
2328
2329            let payload = json!({
2330                "method": "UNSUBSCRIBE",
2331                "params": [stream.clone()],
2332                "id": request_id,
2333            });
2334
2335            info!("UNSUBSCRIBE → {:?}", payload);
2336
2337            let common = Arc::clone(&self.common);
2338            let conn_clone = Arc::clone(&conn);
2339            let msg = serde_json::to_string(&payload).unwrap();
2340            spawn(async move {
2341                let _ = common
2342                    .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2343                    .await;
2344            });
2345
2346            {
2347                let mut connection_streams = self.connection_streams.lock().await;
2348                connection_streams.remove(&key);
2349            }
2350            {
2351                let mut conn_state = conn.state.lock().await;
2352                conn_state.stream_callbacks.remove(&key);
2353            }
2354        }
2355    }
2356
2357    /// Checks if a specific stream is currently subscribed.
2358    ///
2359    /// # Arguments
2360    ///
2361    /// * `stream` - The stream identifier to check for subscription status
2362    ///
2363    /// # Returns
2364    ///
2365    /// `true` if the stream is subscribed, `false` otherwise
2366    ///
2367    /// # Async
2368    ///
2369    /// This method is asynchronous and requires `.await` when called
2370    pub async fn is_subscribed(&self, stream: &str) -> bool {
2371        let map = self.connection_streams.lock().await;
2372
2373        if map.contains_key(stream) {
2374            return true;
2375        }
2376
2377        let suffix = format!("::{}", stream);
2378        map.keys().any(|k| k.ends_with(&suffix))
2379    }
2380
2381    /// Generates a unique key for a stream based on its name and optional URL path.
2382    ///
2383    /// # Arguments
2384    ///
2385    /// * `stream` - The name of the stream
2386    /// * `url_path` - An optional URL path associated with the stream
2387    ///
2388    /// # Returns
2389    ///
2390    /// A `String` representing the unique key for the stream
2391    ///
2392    fn stream_key(&self, stream: &str, url_path: Option<&str>) -> String {
2393        match url_path {
2394            Some(p) if !p.is_empty() => format!("{p}::{stream}"),
2395            _ => stream.to_string(),
2396        }
2397    }
2398
2399    /// Prepares a WebSocket URL for streaming with optional stream names and time unit configuration.
2400    ///
2401    /// # Arguments
2402    ///
2403    /// * `streams` - A slice of stream names to be included in the URL
2404    /// * `url_path` - An optional path to append to the base WebSocket URL
2405    ///
2406    /// # Returns
2407    ///
2408    /// A fully constructed WebSocket URL with optional stream and time unit parameters
2409    ///
2410    /// # Notes
2411    ///
2412    /// - If no time unit is specified, the base URL is returned
2413    /// - Validates and appends the time unit parameter if provided and valid
2414    /// - Handles URL parameter separator based on existing query parameters
2415    fn prepare_url(&self, streams: &[String], url_path: Option<&str>) -> String {
2416        let mut url = format!(
2417            "{}/stream?streams={}",
2418            match url_path {
2419                Some(path) => format!(
2420                    "{}/{}",
2421                    self.configuration.ws_url.as_deref().unwrap_or(""),
2422                    path
2423                ),
2424                None => self
2425                    .configuration
2426                    .ws_url
2427                    .as_deref()
2428                    .unwrap_or("")
2429                    .to_string(),
2430            },
2431            streams.join("/")
2432        );
2433
2434        let time_unit = match &self.configuration.time_unit {
2435            Some(u) => u.to_string(),
2436            None => return url,
2437        };
2438
2439        match validate_time_unit(&time_unit) {
2440            Ok(Some(validated)) => {
2441                let sep = if url.contains('?') { '&' } else { '?' };
2442                url.push(sep);
2443                url.push_str("timeUnit=");
2444                url.push_str(validated);
2445            }
2446            Ok(None) => {}
2447            Err(e) => {
2448                error!("Invalid time unit provided: {:?}", e);
2449            }
2450        }
2451
2452        url
2453    }
2454
2455    /// Handles stream assignment by finding or creating WebSocket connections for a list of streams.
2456    ///
2457    /// This method attempts to assign streams to existing WebSocket connections or creates new
2458    /// connections if needed. It groups streams by their assigned connections and handles scenarios
2459    /// such as closed or pending reconnection connections.
2460    ///
2461    /// # Arguments
2462    ///
2463    /// * `streams` - A vector of stream names to be assigned
2464    /// * `url_path` - An optional URL path associated with the streams
2465    ///
2466    /// # Returns
2467    ///
2468    /// A vector of tuples containing WebSocket connections and their associated streams
2469    ///
2470    /// # Errors
2471    ///
2472    /// Returns an empty result if no connections can be established for the streams
2473    async fn handle_stream_assignment(
2474        &self,
2475        streams: Vec<String>,
2476        url_path: Option<&str>,
2477    ) -> Vec<(Arc<WebsocketConnection>, Vec<String>)> {
2478        let mut connection_streams: Vec<(String, Arc<WebsocketConnection>)> = Vec::new();
2479
2480        for stream in streams {
2481            let key = self.stream_key(&stream, url_path);
2482
2483            let mut conn_opt = {
2484                let map = self.connection_streams.lock().await;
2485                map.get(&key).cloned()
2486            };
2487
2488            let need_new = if let Some(conn) = &conn_opt {
2489                let state = conn.state.lock().await;
2490                state.close_initiated || state.reconnection_pending
2491            } else {
2492                true
2493            };
2494
2495            if need_new {
2496                match self.common.get_connection(true, url_path).await {
2497                    Ok(new_conn) => {
2498                        let mut map = self.connection_streams.lock().await;
2499                        map.insert(key.clone(), new_conn.clone());
2500                        conn_opt = Some(new_conn);
2501                    }
2502                    Err(err) => {
2503                        warn!(
2504                            "No available WebSocket connection to subscribe stream `{}` (key `{}`): {:?}",
2505                            stream, key, err
2506                        );
2507                        continue;
2508                    }
2509                }
2510            }
2511
2512            if let Some(conn) = conn_opt {
2513                {
2514                    let mut conn_state = conn.state.lock().await;
2515                    conn_state.stream_callbacks.entry(key.clone()).or_default();
2516                }
2517                connection_streams.push((stream, conn));
2518            }
2519        }
2520
2521        let mut groups: Vec<(Arc<WebsocketConnection>, Vec<String>)> = Vec::new();
2522        for (stream, conn) in connection_streams {
2523            if let Some((_, vec)) = groups.iter_mut().find(|(c, _)| Arc::ptr_eq(c, &conn)) {
2524                vec.push(stream);
2525            } else {
2526                groups.push((conn, vec![stream]));
2527            }
2528        }
2529
2530        groups
2531    }
2532
2533    /// Sends a WebSocket subscription payload for the specified streams.
2534    ///
2535    /// # Arguments
2536    ///
2537    /// * `connection` - The WebSocket connection to send the subscription on
2538    /// * `streams` - A vector of stream names to subscribe to
2539    /// * `id` - An optional request ID for the subscription (will be randomly generated if not provided)
2540    ///
2541    /// # Remarks
2542    ///
2543    /// This method constructs a SUBSCRIBE payload, logs it, and sends it asynchronously using the WebSocket connection.
2544    /// If serialization fails, an error is logged and the method returns without sending.
2545    fn send_subscription_payload(
2546        &self,
2547        connection: &Arc<WebsocketConnection>,
2548        streams: &Vec<String>,
2549        id: Option<StreamId>,
2550    ) {
2551        let request_id = normalize_stream_id(
2552            id.clone(),
2553            self.stream_id_is_strictly_number.load(Ordering::Relaxed),
2554        );
2555
2556        let payload = json!({
2557            "method": "SUBSCRIBE",
2558            "params": streams,
2559            "id": request_id,
2560        });
2561
2562        info!("SUBSCRIBE → {:?}", payload);
2563
2564        let common = Arc::clone(&self.common);
2565        let msg = match serde_json::to_string(&payload) {
2566            Ok(s) => s,
2567            Err(e) => {
2568                error!("Failed to serialize SUBSCRIBE payload: {}", e);
2569                return;
2570            }
2571        };
2572        let conn_clone = Arc::clone(connection);
2573
2574        spawn(async move {
2575            let _ = common
2576                .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2577                .await;
2578        });
2579    }
2580}
2581
2582#[async_trait]
2583impl WebsocketHandler for WebsocketStreams {
2584    /// Handles the WebSocket connection opening by processing any pending subscriptions.
2585    ///
2586    /// This method is called when a WebSocket connection is established. It retrieves
2587    /// any pending stream subscriptions from the connection state and sends them
2588    /// immediately using the `send_subscription_payload` method.
2589    ///
2590    /// # Arguments
2591    ///
2592    /// * `_url` - The URL of the WebSocket connection (unused)
2593    /// * `connection` - The WebSocket connection that has just been opened
2594    ///
2595    /// # Remarks
2596    ///
2597    /// If there are any pending subscriptions, they are sent as a batch subscription
2598    /// payload. The method uses a lock to safely access and clear the pending subscriptions
2599    /// from the connection state.
2600    async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
2601        let pending_subs: Vec<String> = {
2602            let mut conn_state = connection.state.lock().await;
2603            take(&mut conn_state.pending_subscriptions)
2604                .into_iter()
2605                .collect()
2606        };
2607
2608        if !pending_subs.is_empty() {
2609            info!("Processing queued subscriptions for connection");
2610            self.send_subscription_payload(&connection, &pending_subs, None);
2611        }
2612    }
2613
2614    /// Handles incoming WebSocket stream messages by parsing the JSON payload and invoking registered stream callbacks.
2615    ///
2616    /// This method processes WebSocket messages with a specific structure, extracting the stream name and data.
2617    /// It retrieves and executes any registered callbacks associated with the stream name.
2618    ///
2619    /// # Arguments
2620    ///
2621    /// * `data` - The raw WebSocket message as a JSON-formatted string
2622    /// * `connection` - The WebSocket connection through which the message was received
2623    ///
2624    /// # Behavior
2625    ///
2626    /// - Parses the JSON message
2627    /// - Extracts the stream name and data payload
2628    /// - Looks up and invokes any registered callbacks for the stream
2629    /// - Silently returns if message parsing or stream extraction fails
2630    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
2631        let msg: Value = match serde_json::from_str(&data) {
2632            Ok(v) => v,
2633            Err(err) => {
2634                error!(
2635                    "Failed to parse WebSocket stream message {} – {}",
2636                    data, err
2637                );
2638                return;
2639            }
2640        };
2641
2642        let (stream_name, payload) = match (
2643            msg.get("stream").and_then(Value::as_str),
2644            msg.get("data").cloned(),
2645        ) {
2646            (Some(name), Some(data)) => (name.to_string(), data),
2647            _ => return,
2648        };
2649
2650        let callbacks = {
2651            let conn_state = connection.state.lock().await;
2652            let key = self.stream_key(&stream_name, conn_state.url_path.as_deref());
2653            conn_state
2654                .stream_callbacks
2655                .get(&key)
2656                .cloned()
2657                .unwrap_or_else(Vec::new)
2658        };
2659
2660        for callback in callbacks {
2661            callback(&payload);
2662        }
2663    }
2664
2665    /// Retrieves the reconnection URL for a specific WebSocket connection by identifying all streams associated with that connection.
2666    ///
2667    /// # Arguments
2668    ///
2669    /// * `_default_url` - A default URL that can be used if no specific reconnection URL is determined
2670    /// * `connection` - The WebSocket connection for which to generate a reconnection URL
2671    ///
2672    /// # Returns
2673    ///
2674    /// A URL string that can be used to reconnect to the WebSocket, based on the streams associated with the given connection
2675    async fn get_reconnect_url(
2676        &self,
2677        _default_url: String,
2678        connection: Arc<WebsocketConnection>,
2679    ) -> String {
2680        let connection_streams = self.connection_streams.lock().await;
2681        let reconnect_streams = connection_streams
2682            .iter()
2683            .filter_map(|(key, conn_arc)| {
2684                if Arc::ptr_eq(conn_arc, &connection) {
2685                    let stream = match key.split_once("::") {
2686                        Some((_prefix, rest)) => rest.to_string(),
2687                        None => key.clone(),
2688                    };
2689                    Some(stream)
2690                } else {
2691                    None
2692                }
2693            })
2694            .collect::<Vec<_>>();
2695
2696        let url_path = {
2697            let st = connection.state.lock().await;
2698            st.url_path.as_deref().map(std::string::ToString::to_string)
2699        };
2700
2701        self.prepare_url(&reconnect_streams, url_path.as_deref())
2702    }
2703}
2704
/// Handle for one stream (or, on `WebsocketApi`, one stream id) whose JSON payloads are
/// deserialized into `T` before being passed to a user callback.
pub struct WebsocketStream<T> {
    /// The transport this handle is attached to (`WebsocketStreams` or `WebsocketApi`).
    websocket_base: WebsocketBase,
    /// Stream name (`WebsocketStreams`) or identifier (`WebsocketApi`) used to key callbacks.
    stream_or_id: String,
    /// Optional URL path the stream was subscribed under; combined with `stream_or_id` to
    /// build the `connection_streams` key on `WebsocketStreams`.
    url_path: Option<String>,
    /// The most recently registered type-erased callback, kept so `unsubscribe` can remove it.
    callback: Mutex<Option<Arc<dyn Fn(&Value) + Send + Sync>>>,
    /// Optional request id forwarded to the subscribe and unsubscribe requests.
    pub id: Option<StreamId>,
    /// Ties the handle to the payload type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
2713
2714impl<T> WebsocketStream<T>
2715where
2716    T: DeserializeOwned + Send + 'static,
2717{
2718    /// Registers a callback function for a specific event on the WebSocket stream.
2719    ///
2720    /// This method currently only supports the "message" event. When a message is received,
2721    /// the provided callback function will be invoked with the deserialized payload.
2722    ///
2723    /// # Arguments
2724    ///
2725    /// * `event` - The event type to listen for (currently only "message" is supported)
2726    /// * `callback_fn` - A function that will be called with the deserialized message payload
2727    ///
2728    /// # Errors
2729    ///
2730    /// Logs an error if the payload cannot be deserialized into the expected type
2731    ///
2732    /// # Examples
2733    ///
2734    ///
2735    /// stream.on("message", |data: `MyType`| {
2736    ///     // Handle the deserialized message
2737    /// });
2738    async fn on<F>(&self, event: &str, callback_fn: F)
2739    where
2740        F: Fn(T) + Send + Sync + 'static,
2741    {
2742        if event != "message" {
2743            return;
2744        }
2745
2746        let cb_wrapper: Arc<dyn Fn(&Value) + Send + Sync> =
2747            Arc::new(
2748                move |v: &Value| match serde_json::from_value::<T>(v.clone()) {
2749                    Ok(data) => callback_fn(data),
2750                    Err(e) => error!("Failed to deserialize stream payload: {:?}", e),
2751                },
2752            );
2753
2754        {
2755            let mut guard = self.callback.lock().await;
2756            *guard = Some(cb_wrapper.clone());
2757        }
2758
2759        match &self.websocket_base {
2760            WebsocketBase::WebsocketStreams(ws_streams) => {
2761                let key = ws_streams.stream_key(&self.stream_or_id, self.url_path.as_deref());
2762                let conn = {
2763                    let map = ws_streams.connection_streams.lock().await;
2764                    map.get(&key).cloned().expect("stream must be subscribed")
2765                };
2766
2767                {
2768                    let mut conn_state = conn.state.lock().await;
2769                    let entry = conn_state.stream_callbacks.entry(key).or_default();
2770
2771                    if !entry
2772                        .iter()
2773                        .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2774                    {
2775                        entry.push(cb_wrapper);
2776                    }
2777                }
2778            }
2779            WebsocketBase::WebsocketApi(ws_api) => {
2780                let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2781                let entry = stream_callbacks
2782                    .entry(self.stream_or_id.clone())
2783                    .or_default();
2784
2785                if !entry
2786                    .iter()
2787                    .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2788                {
2789                    entry.push(cb_wrapper);
2790                }
2791            }
2792        }
2793    }
2794
2795    /// Synchronously sets a message callback for the WebSocket stream on the current thread.
2796    ///
2797    /// # Arguments
2798    ///
2799    /// * `callback_fn` - A function that will be called with the deserialized message payload
2800    ///
2801    /// # Panics
2802    ///
2803    /// Panics if the thread runtime fails to be created or if the thread join fails
2804    ///
2805    /// # Examples
2806    ///
2807    ///
2808    /// let stream = `Arc::new(WebsocketStream::new())`;
2809    /// `stream.on_message(|data`: `MyType`| {
2810    ///     // Handle the deserialized message
2811    /// });
2812    ///
2813    pub fn on_message<F>(self: &Arc<Self>, callback_fn: F)
2814    where
2815        T: Send + Sync,
2816        F: Fn(T) + Send + Sync + 'static,
2817    {
2818        let handler: Arc<Self> = Arc::clone(self);
2819
2820        std::thread::spawn(move || {
2821            let rt = tokio::runtime::Builder::new_current_thread()
2822                .enable_all()
2823                .build()
2824                .expect("failed to build Tokio runtime");
2825
2826            rt.block_on(handler.on("message", callback_fn));
2827        })
2828        .join()
2829        .expect("on_message thread panicked");
2830    }
2831
2832    /// Unsubscribes from the current WebSocket stream and removes the associated callback.
2833    ///
2834    /// This method performs the following actions:
2835    /// - Removes the current callback associated with the stream
2836    /// - Removes the callback from the connection's stream callbacks
2837    /// - Asynchronously unsubscribes from the stream using the WebSocket streams base
2838    ///
2839    /// # Panics
2840    ///
2841    /// Panics if the stream is not subscribed to
2842    ///
2843    /// # Notes
2844    /// - If no callback is present, no action is taken
2845    /// - Spawns an asynchronous task to handle the unsubscription process
2846    pub async fn unsubscribe(&self) {
2847        let maybe_cb = {
2848            let mut guard = self.callback.lock().await;
2849            guard.take()
2850        };
2851
2852        if let Some(cb) = maybe_cb {
2853            match &self.websocket_base {
2854                WebsocketBase::WebsocketStreams(ws_streams) => {
2855                    let key = ws_streams.stream_key(&self.stream_or_id, self.url_path.as_deref());
2856                    let conn = {
2857                        let map = ws_streams.connection_streams.lock().await;
2858                        map.get(&key)
2859                            .cloned()
2860                            .expect("stream must have been subscribed")
2861                    };
2862
2863                    {
2864                        let mut conn_state = conn.state.lock().await;
2865                        if let Some(list) = conn_state.stream_callbacks.get_mut(&key) {
2866                            list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2867                        }
2868                    }
2869
2870                    let stream = self.stream_or_id.clone();
2871                    let id = self.id.clone();
2872                    let url_path = self.url_path.clone();
2873                    let websocket_streams_base = Arc::clone(ws_streams);
2874                    spawn(async move {
2875                        websocket_streams_base
2876                            .unsubscribe(vec![stream], id, url_path.as_deref())
2877                            .await;
2878                    });
2879                }
2880                WebsocketBase::WebsocketApi(ws_api) => {
2881                    let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2882                    if let Some(list) = stream_callbacks.get_mut(&self.stream_or_id) {
2883                        list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2884                    }
2885                }
2886            }
2887        }
2888    }
2889}
2890
2891/// Creates a new WebSocket stream handler for the specified stream or ID.
2892/// This function subscribes to the stream if the WebSocket base is of type `WebsocketStreams`.
2893///
2894/// # Arguments
2895///
2896/// * `websocket_base` - The base WebSocket instance (either `WebsocketStreams` or `WebsocketApi`)
2897/// * `stream_or_id` - The stream name or identifier to subscribe to
2898/// * `id` - An optional request identifier for the subscription
2899/// * `url_path` - An optional URL path for the subscription
2900///
2901/// # Returns
2902///
2903/// A new `WebsocketStream` instance.
2904///
2905pub async fn create_stream_handler<T>(
2906    websocket_base: WebsocketBase,
2907    stream_or_id: String,
2908    id: Option<StreamId>,
2909    url_path: Option<String>,
2910) -> Arc<WebsocketStream<T>>
2911where
2912    T: DeserializeOwned + Send + 'static,
2913{
2914    match &websocket_base {
2915        WebsocketBase::WebsocketStreams(ws_streams) => {
2916            ws_streams
2917                .clone()
2918                .subscribe(vec![stream_or_id.clone()], id.clone(), url_path.as_deref())
2919                .await;
2920        }
2921        WebsocketBase::WebsocketApi(_) => {}
2922    }
2923
2924    Arc::new(WebsocketStream {
2925        websocket_base,
2926        stream_or_id,
2927        url_path,
2928        id,
2929        callback: Mutex::new(None),
2930        _phantom: PhantomData,
2931    })
2932}
2933
2934#[cfg(test)]
2935mod tests {
2936    use crate::TOKIO_SHARED_RT;
2937    use crate::common::utils::{SignatureGenerator, build_user_agent};
2938    use crate::common::websocket::{
2939        PendingRequest, ReconnectEntry, SendWebsocketMessageResult, WebsocketApi, WebsocketBase,
2940        WebsocketCommon, WebsocketConnection, WebsocketEvent, WebsocketEventEmitter,
2941        WebsocketHandler, WebsocketMessageSendOptions, WebsocketMode, WebsocketSessionLogonReq,
2942        WebsocketStream, WebsocketStreams, create_stream_handler,
2943    };
2944    use crate::config::{ConfigurationWebsocketApi, ConfigurationWebsocketStreams, PrivateKey};
2945    use crate::errors::{WebsocketConnectionFailureReason, WebsocketError};
2946    use crate::models::{StreamId, TimeUnit};
2947    use async_trait::async_trait;
2948    use futures::{SinkExt, StreamExt};
2949    use http::header::USER_AGENT;
2950    use regex::Regex;
2951    use serde_json::{Value, json};
2952    use std::collections::{BTreeMap, HashSet};
2953    use std::marker::PhantomData;
2954    use std::net::SocketAddr;
2955    use std::sync::{
2956        Arc,
2957        atomic::{AtomicBool, AtomicUsize, Ordering},
2958    };
2959    use tokio::net::TcpListener;
2960    use tokio::sync::{
2961        Mutex,
2962        mpsc::{Receiver, unbounded_channel},
2963        oneshot,
2964    };
2965    use tokio::time::{Duration, advance, pause, resume, sleep, timeout};
2966    use tokio_tungstenite::{accept_async, accept_hdr_async, tungstenite, tungstenite::Message};
2967    use tungstenite::handshake::server::Request;
2968
2969    fn subscribe_events(common: &WebsocketCommon) -> Arc<Mutex<Vec<WebsocketEvent>>> {
2970        let events = Arc::new(Mutex::new(Vec::new()));
2971        let events_clone = events.clone();
2972        common.events.subscribe(move |event| {
2973            let events_clone = events_clone.clone();
2974            tokio::spawn(async move {
2975                events_clone.lock().await.push(event);
2976            });
2977        });
2978        events
2979    }
2980
2981    async fn create_connection(
2982        id: &str,
2983        has_writer: bool,
2984        reconnection_pending: bool,
2985        renewal_pending: bool,
2986        close_initiated: bool,
2987    ) -> Arc<WebsocketConnection> {
2988        let conn = WebsocketConnection::new(id);
2989        let mut st = conn.state.lock().await;
2990        st.reconnection_pending = reconnection_pending;
2991        st.renewal_pending = renewal_pending;
2992        st.close_initiated = close_initiated;
2993        if has_writer {
2994            let (tx, _) = unbounded_channel::<Message>();
2995            st.ws_write_tx = Some(tx);
2996        } else {
2997            st.ws_write_tx = None;
2998        }
2999        drop(st);
3000        conn
3001    }
3002
3003    fn create_websocket_api(
3004        time_unit: Option<TimeUnit>,
3005        mode: Option<WebsocketMode>,
3006        auto_session_relogon: Option<bool>,
3007    ) -> Arc<WebsocketApi> {
3008        let mode = mode.unwrap_or(WebsocketMode::Single);
3009        let auto_session_relogon = auto_session_relogon.unwrap_or(true);
3010        let sig_gen = SignatureGenerator::new(
3011            Some("api_secret".into()),
3012            None::<PrivateKey>,
3013            None::<String>,
3014        );
3015        let config = ConfigurationWebsocketApi {
3016            api_key: Some("api_key".into()),
3017            api_secret: Some("api_secret".into()),
3018            private_key: None,
3019            private_key_passphrase: None,
3020            ws_url: Some("wss://example.com".into()),
3021            mode,
3022            reconnect_delay: 1000,
3023            signature_gen: sig_gen,
3024            timeout: 500,
3025            time_unit,
3026            auto_session_relogon,
3027            agent: None,
3028            user_agent: build_user_agent("product"),
3029        };
3030        let conn1 = WebsocketConnection::new("c1");
3031        let conn2 = WebsocketConnection::new("c2");
3032        WebsocketApi::new(config, vec![conn1, conn2])
3033    }
3034
3035    fn create_websocket_streams(
3036        ws_url: Option<&str>,
3037        conns: Option<Vec<Arc<WebsocketConnection>>>,
3038        url_paths: Option<Vec<String>>,
3039    ) -> Arc<WebsocketStreams> {
3040        let mut connections: Vec<Arc<WebsocketConnection>> = vec![];
3041        let url_paths = url_paths.unwrap_or_default();
3042        if conns.is_none() {
3043            connections.push(WebsocketConnection::new("c1"));
3044            connections.push(WebsocketConnection::new("c2"));
3045        } else {
3046            connections = conns.expect("Expected connections to be set");
3047        }
3048        let config = ConfigurationWebsocketStreams {
3049            ws_url: Some(ws_url.unwrap_or("example.com").to_string()),
3050            mode: WebsocketMode::Single,
3051            reconnect_delay: 500,
3052            time_unit: None,
3053            agent: None,
3054            user_agent: build_user_agent("product"),
3055        };
3056        WebsocketStreams::new(config, connections, url_paths)
3057    }
3058
3059    fn subscribe_to_emitter(emitter: &WebsocketEventEmitter) -> Receiver<WebsocketEvent> {
3060        let (test_tx, test_rx) = tokio::sync::mpsc::channel(16);
3061        let _sub = emitter.subscribe(move |evt| {
3062            let _ = test_tx.try_send(evt);
3063        });
3064        test_rx
3065    }
3066
3067    async fn expect_websocket_event(rx: &mut Receiver<WebsocketEvent>) -> WebsocketEvent {
3068        timeout(Duration::from_millis(200), rx.recv())
3069            .await
3070            .expect("timed out waiting for event")
3071            .expect("subscriber channel closed")
3072    }
3073
3074    async fn eventually_async<F, Fut>(max_wait: Duration, mut f: F) -> bool
3075    where
3076        F: FnMut() -> Fut,
3077        Fut: std::future::Future<Output = bool>,
3078    {
3079        let start = tokio::time::Instant::now();
3080        while start.elapsed() < max_wait {
3081            if f().await {
3082                return true;
3083            }
3084            sleep(Duration::from_millis(20)).await;
3085        }
3086        false
3087    }
3088
3089    mod event_emitter {
3090        use super::*;
3091
3092        #[test]
3093        fn event_emitter_subscribe_and_emit() {
3094            TOKIO_SHARED_RT.block_on(async {
3095                let emitter = WebsocketEventEmitter::new();
3096                let (tx, rx) = oneshot::channel();
3097                let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
3098                let tx_clone = tx.clone();
3099                let _sub = emitter.subscribe(move |event| {
3100                    if let Some(sender) = tx_clone.lock().unwrap().take() {
3101                        let _ = sender.send(event);
3102                    }
3103                });
3104                emitter.emit(&WebsocketEvent::Open);
3105                let received = timeout(Duration::from_millis(100), rx)
3106                    .await
3107                    .expect("timed out");
3108                assert_eq!(received, Ok(WebsocketEvent::Open));
3109            });
3110        }
3111
3112        #[test]
3113        fn single_subscriber_gets_event() {
3114            TOKIO_SHARED_RT.block_on(async {
3115                let emitter = WebsocketEventEmitter::new();
3116                let mut rx = subscribe_to_emitter(&emitter);
3117
3118                let e1 = WebsocketEvent::Open;
3119                emitter.emit(&e1);
3120
3121                let got = expect_websocket_event(&mut rx).await;
3122                assert_eq!(got, e1);
3123            });
3124        }
3125
3126        #[test]
3127        fn multiple_subscribers_get_event() {
3128            TOKIO_SHARED_RT.block_on(async {
3129                let emitter = WebsocketEventEmitter::new();
3130                let mut rx1 = subscribe_to_emitter(&emitter);
3131                let mut rx2 = subscribe_to_emitter(&emitter);
3132
3133                let e = WebsocketEvent::Message("hello".into());
3134                emitter.emit(&e);
3135
3136                assert_eq!(expect_websocket_event(&mut rx1).await, e.clone());
3137                assert_eq!(expect_websocket_event(&mut rx2).await, e);
3138            });
3139        }
3140
3141        #[test]
3142        fn closed_subscribers_are_pruned() {
3143            TOKIO_SHARED_RT.block_on(async {
3144                let emitter = WebsocketEventEmitter::new();
3145                let rx1 = subscribe_to_emitter(&emitter);
3146                let mut rx2 = subscribe_to_emitter(&emitter);
3147                drop(rx1);
3148
3149                let e = WebsocketEvent::Pong;
3150                emitter.emit(&e);
3151
3152                assert_eq!(expect_websocket_event(&mut rx2).await, e);
3153            });
3154        }
3155
3156        #[test]
3157        fn prune_on_error_does_not_hang() {
3158            TOKIO_SHARED_RT.block_on(async {
3159                let emitter = WebsocketEventEmitter::new();
3160                let rx = subscribe_to_emitter(&emitter);
3161                drop(rx);
3162
3163                let e = WebsocketEvent::Close(1000, "bye".into());
3164                emitter.emit(&e);
3165            });
3166        }
3167    }
3168
3169    mod websocket_common {
3170        use super::*;
3171
3172        mod initialisation {
3173            use super::*;
3174
3175            #[test]
3176            fn single_mode() {
3177                TOKIO_SHARED_RT.block_on(async {
3178                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3179                    assert_eq!(common.connection_pool.len(), 1);
3180                });
3181            }
3182
3183            #[test]
3184            fn pool_mode() {
3185                TOKIO_SHARED_RT.block_on(async {
3186                    let common =
3187                        WebsocketCommon::new(vec![], WebsocketMode::Pool(3), 0, None, None);
3188                    assert_eq!(common.connection_pool.len(), 3);
3189                });
3190            }
3191        }
3192
3193        mod spawn_reconnect_loop {
3194            use super::*;
3195
3196            #[test]
3197            fn successful_reconnect_entry_triggers_init_connect() {
3198                TOKIO_SHARED_RT.block_on(async {
3199                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3200                    let addr = listener.local_addr().unwrap();
3201                    tokio::spawn(async move {
3202                        if let Ok((stream, _)) = listener.accept().await {
3203                            let mut ws = accept_async(stream).await.unwrap();
3204                            sleep(Duration::from_secs(5)).await;
3205                            let _ = ws.close(None).await;
3206                        }
3207                    });
3208
3209                    let conn = WebsocketConnection::new("c1");
3210                    let common = WebsocketCommon::new(
3211                        vec![conn.clone()],
3212                        WebsocketMode::Single,
3213                        10,
3214                        None,
3215                        None,
3216                    );
3217                    let url = format!("ws://{addr}");
3218                    common
3219                        .reconnect_tx
3220                        .send(ReconnectEntry {
3221                            connection_id: "c1".into(),
3222                            url: url.clone(),
3223                            is_renewal: false,
3224                        })
3225                        .await
3226                        .unwrap();
3227
3228                    let mut ok = false;
3229                    for _ in 0..100 {
3230                        if conn.state.lock().await.ws_write_tx.is_some() {
3231                            ok = true;
3232                            break;
3233                        }
3234                        sleep(Duration::from_millis(50)).await;
3235                    }
3236                    assert!(ok, "expected ws_write_tx to be Some after reconnect");
3237                });
3238            }
3239
3240            #[test]
3241            fn reconnect_entry_with_unknown_id_is_ignored() {
3242                TOKIO_SHARED_RT.block_on(async {
3243                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3244                    let addr = listener.local_addr().unwrap();
3245                    tokio::spawn(async move {
3246                        if let Ok((stream, _)) = listener.accept().await {
3247                            let mut ws = accept_async(stream).await.unwrap();
3248                            let _ = ws.close(None).await;
3249                        }
3250                    });
3251
3252                    let conn = WebsocketConnection::new("c1");
3253                    let common = WebsocketCommon::new(
3254                        vec![conn.clone()],
3255                        WebsocketMode::Single,
3256                        5,
3257                        None,
3258                        None,
3259                    );
3260                    let url = format!("ws://{addr}");
3261                    common
3262                        .reconnect_tx
3263                        .send(ReconnectEntry {
3264                            connection_id: "other".into(),
3265                            url,
3266                            is_renewal: false,
3267                        })
3268                        .await
3269                        .unwrap();
3270
3271                    sleep(Duration::from_secs(1)).await;
3272
3273                    let st = conn.state.lock().await;
3274                    assert!(st.ws_write_tx.is_none());
3275                });
3276            }
3277
3278            #[test]
3279            fn renewal_entries_bypass_initial_delay() {
3280                TOKIO_SHARED_RT.block_on(async {
3281                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3282                    let addr = listener.local_addr().unwrap();
3283                    tokio::spawn(async move {
3284                        if let Ok((stream, _)) = listener.accept().await {
3285                            let mut ws = accept_async(stream).await.unwrap();
3286                            let _ = ws.close(None).await;
3287                        }
3288                    });
3289
3290                    let conn = WebsocketConnection::new("renew");
3291                    let common = WebsocketCommon::new(
3292                        vec![conn.clone()],
3293                        WebsocketMode::Single,
3294                        200,
3295                        None,
3296                        None,
3297                    );
3298                    let url = format!("ws://{addr}");
3299                    common
3300                        .reconnect_tx
3301                        .send(ReconnectEntry {
3302                            connection_id: "renew".into(),
3303                            url: url.clone(),
3304                            is_renewal: true,
3305                        })
3306                        .await
3307                        .unwrap();
3308
3309                    sleep(Duration::from_secs(2)).await;
3310
3311                    let st = conn.state.lock().await;
3312
3313                    assert!(st.ws_write_tx.is_some());
3314                });
3315            }
3316
3317            #[test]
3318            fn non_renewal_entries_respect_initial_delay() {
3319                TOKIO_SHARED_RT.block_on(async {
3320                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3321                    let addr = listener.local_addr().unwrap();
3322                    tokio::spawn(async move {
3323                        if let Ok((stream, _)) = listener.accept().await {
3324                            let mut ws = accept_async(stream).await.unwrap();
3325                            sleep(Duration::from_secs(2)).await;
3326                            let _ = ws.close(None).await;
3327                        }
3328                    });
3329
3330                    let conn = WebsocketConnection::new("nonrenew");
3331                    let common = WebsocketCommon::new(
3332                        vec![conn.clone()],
3333                        WebsocketMode::Single,
3334                        200,
3335                        None,
3336                        None,
3337                    );
3338                    let url = format!("ws://{addr}");
3339                    common
3340                        .reconnect_tx
3341                        .send(ReconnectEntry {
3342                            connection_id: "nonrenew".into(),
3343                            url: url.clone(),
3344                            is_renewal: false,
3345                        })
3346                        .await
3347                        .unwrap();
3348
3349                    sleep(Duration::from_millis(50)).await;
3350                    assert!(conn.state.lock().await.ws_write_tx.is_none());
3351
3352                    let mut ok = false;
3353                    for _ in 0..200 {
3354                        if conn.state.lock().await.ws_write_tx.is_some() {
3355                            ok = true;
3356                            break;
3357                        }
3358                        sleep(Duration::from_millis(50)).await;
3359                    }
3360                    assert!(
3361                        ok,
3362                        "expected ws_write_tx to be Some after reconnect delay elapsed"
3363                    );
3364                });
3365            }
3366        }
3367
3368        mod spawn_renewal_loop {
3369            use super::*;
3370
3371            #[tokio::test]
3372            async fn scheduling_renewal_does_not_panic_for_known_connection() {
3373                pause();
3374
3375                let conn = WebsocketConnection::new("known");
3376                let common =
3377                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3378                let url = "wss://example".to_string();
3379                common
3380                    .renewal_tx
3381                    .send((conn.id.clone(), url))
3382                    .await
3383                    .unwrap();
3384                advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
3385                resume();
3386            }
3387
3388            #[tokio::test]
3389            async fn scheduling_renewal_ignored_for_unknown_connection() {
3390                pause();
3391
3392                let conn = WebsocketConnection::new("c1");
3393                let common =
3394                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3395                common
3396                    .renewal_tx
3397                    .send(("other".into(), "u".into()))
3398                    .await
3399                    .unwrap();
3400                advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
3401
3402                resume();
3403            }
3404        }
3405
3406        mod reconnect_regressions {
3407            use super::*;
3408
3409            #[test]
3410            fn init_connect_is_not_skipped_when_reconnection_pending() {
3411                TOKIO_SHARED_RT.block_on(async {
3412                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3413                    let addr = listener.local_addr().unwrap();
3414
3415                    tokio::spawn(async move {
3416                        if let Ok((stream, _)) = listener.accept().await {
3417                            tokio::spawn(async move {
3418                                let mut ws = accept_async(stream).await.unwrap();
3419                                sleep(Duration::from_millis(500)).await;
3420                                let _ = ws.close(None).await;
3421                            });
3422                        }
3423                    });
3424
3425                    let conn = WebsocketConnection::new("c-reconnect");
3426                    {
3427                        let mut st = conn.state.lock().await;
3428                        let (tx, _) = unbounded_channel::<Message>();
3429                        st.ws_write_tx = Some(tx);
3430                        st.reconnection_pending = true;
3431                    }
3432
3433                    let common = WebsocketCommon::new(
3434                        vec![conn.clone()],
3435                        WebsocketMode::Single,
3436                        0,
3437                        None,
3438                        None,
3439                    );
3440
3441                    let url = format!("ws://{addr}");
3442                    common
3443                        .clone()
3444                        .init_connect(&url, false, Some(conn.clone()))
3445                        .await
3446                        .unwrap();
3447
3448                    let ok = eventually_async(Duration::from_secs(2), || {
3449                        let conn = conn.clone();
3450                        async move {
3451                            let st = conn.state.lock().await;
3452                            st.ws_write_tx.is_some() && !st.reconnection_pending
3453                        }
3454                    })
3455                    .await;
3456
3457                    assert!(
3458                        ok,
3459                        "expected writer installed and reconnection_pending cleared"
3460                    );
3461                });
3462            }
3463
3464            #[test]
3465            fn pending_request_is_resolved_on_socket_drop() {
3466                TOKIO_SHARED_RT.block_on(async {
3467                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3468                    let addr = listener.local_addr().unwrap();
3469
3470                    tokio::spawn(async move {
3471                        if let Ok((stream, _)) = listener.accept().await {
3472                            let mut ws = accept_async(stream).await.unwrap();
3473                            let _ = ws.next().await;
3474                            let _ = ws.close(None).await;
3475                        }
3476                    });
3477
3478                    let conn = WebsocketConnection::new("c1");
3479                    let common = WebsocketCommon::new(
3480                        vec![conn.clone()],
3481                        WebsocketMode::Single,
3482                        0,
3483                        None,
3484                        None,
3485                    );
3486
3487                    let url = format!("ws://{addr}");
3488                    common
3489                        .clone()
3490                        .init_connect(&url, false, Some(conn.clone()))
3491                        .await
3492                        .unwrap();
3493
3494                    let rx = common
3495                        .send(
3496                            "{\"id\":\"req-1\",\"method\":\"PING\"}".to_string(),
3497                            Some("req-1".to_string()),
3498                            true,
3499                            Duration::from_millis(150),
3500                            Some(conn.clone()),
3501                        )
3502                        .await
3503                        .unwrap()
3504                        .expect("expected oneshot receiver");
3505
3506                    let res = timeout(Duration::from_secs(2), rx)
3507                        .await
3508                        .expect("did not resolve pending request")
3509                        .expect("oneshot cancelled");
3510
3511                    assert!(matches!(res, Err(WebsocketError::Timeout)));
3512
3513                    let ok = eventually_async(Duration::from_secs(1), || {
3514                        let conn = conn.clone();
3515                        async move { conn.state.lock().await.pending_requests.is_empty() }
3516                    })
3517                    .await;
3518
3519                    assert!(ok, "pending_requests should be drained");
3520                });
3521            }
3522        }
3523
3524        mod is_connection_ready {
3525            use super::*;
3526
3527            #[test]
3528            fn is_connection_ready() {
3529                TOKIO_SHARED_RT.block_on(async {
3530                    let conn = WebsocketConnection::new("c1");
3531                    let common = WebsocketCommon::new(
3532                        vec![conn.clone()],
3533                        WebsocketMode::Single,
3534                        0,
3535                        None,
3536                        None,
3537                    );
3538                    assert!(!common.is_connection_ready(&conn, false).await);
3539                    assert!(common.is_connection_ready(&conn, true).await);
3540                });
3541            }
3542
3543            #[test]
3544            fn connection_ready_basic() {
3545                TOKIO_SHARED_RT.block_on(async {
3546                    let conn = create_connection("c1", true, false, false, false).await;
3547                    let common = WebsocketCommon::new(
3548                        vec![conn.clone()],
3549                        WebsocketMode::Single,
3550                        0,
3551                        None,
3552                        None,
3553                    );
3554                    assert!(common.is_connection_ready(&conn, false).await);
3555                });
3556            }
3557
3558            #[test]
3559            fn connection_not_ready_without_writer() {
3560                TOKIO_SHARED_RT.block_on(async {
3561                    let conn = create_connection("c1", false, false, false, false).await;
3562                    let common = WebsocketCommon::new(
3563                        vec![conn.clone()],
3564                        WebsocketMode::Single,
3565                        0,
3566                        None,
3567                        None,
3568                    );
3569                    assert!(!common.is_connection_ready(&conn, false).await);
3570                    assert!(common.is_connection_ready(&conn, true).await);
3571                });
3572            }
3573
3574            #[test]
3575            fn connection_not_ready_when_flagged() {
3576                TOKIO_SHARED_RT.block_on(async {
3577                    let conn1 = create_connection("c1", true, true, false, false).await;
3578                    let conn2 = create_connection("c2", true, false, true, false).await;
3579                    let conn3 = create_connection("c3", true, false, false, true).await;
3580
3581                    let common = WebsocketCommon::new(
3582                        vec![conn1.clone(), conn2.clone(), conn3.clone()],
3583                        WebsocketMode::Pool(3),
3584                        0,
3585                        None,
3586                        None,
3587                    );
3588
3589                    assert!(!common.is_connection_ready(&conn1, false).await);
3590                    assert!(common.is_connection_ready(&conn2, false).await);
3591                    assert!(!common.is_connection_ready(&conn3, false).await);
3592                });
3593            }
3594        }
3595
3596        mod is_connected {
3597            use super::*;
3598
3599            #[test]
3600            fn with_pool_various_connections() {
3601                TOKIO_SHARED_RT.block_on(async {
3602                    let conn_a = create_connection("a", true, false, false, false).await;
3603                    let conn_b = create_connection("b", false, false, false, false).await;
3604                    let conn_c = create_connection("c", true, true, false, false).await;
3605                    let pool = vec![conn_a.clone(), conn_b.clone(), conn_c.clone()];
3606                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None, None);
3607
3608                    assert!(common.is_connected(None).await);
3609                    assert!(common.is_connected(Some(&conn_a)).await);
3610                    assert!(!common.is_connected(Some(&conn_b)).await);
3611                    assert!(!common.is_connected(Some(&conn_c)).await);
3612                });
3613            }
3614
3615            #[test]
3616            fn with_pool_all_bad_connections() {
3617                TOKIO_SHARED_RT.block_on(async {
3618                    let bad1 = create_connection("c1", false, false, false, false).await;
3619                    let bad2 = create_connection("c2", true, true, false, false).await;
3620                    let bad3 = create_connection("c3", true, false, false, true).await;
3621                    let common = WebsocketCommon::new(
3622                        vec![bad1, bad2, bad3],
3623                        WebsocketMode::Pool(3),
3624                        0,
3625                        None,
3626                        None,
3627                    );
3628
3629                    assert!(!common.is_connected(None).await);
3630                });
3631            }
3632
3633            #[test]
3634            fn with_pool_ignore_close_initiated() {
3635                TOKIO_SHARED_RT.block_on(async {
3636                    let good = create_connection("c1", true, false, false, false).await;
3637                    let closed = create_connection("c2", true, false, false, true).await;
3638                    let bad = create_connection("c3", false, false, false, false).await;
3639                    let common = WebsocketCommon::new(
3640                        vec![closed.clone(), good.clone(), bad.clone()],
3641                        WebsocketMode::Pool(3),
3642                        0,
3643                        None,
3644                        None,
3645                    );
3646
3647                    assert!(common.is_connected(None).await);
3648                    assert!(!common.is_connected(Some(&closed)).await);
3649                });
3650            }
3651        }
3652
3653        mod get_available_connections {
3654            use super::*;
3655
3656            #[test]
3657            fn single_mode() {
3658                TOKIO_SHARED_RT.block_on(async {
3659                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3660                    let connections = common.get_available_connections(false, None).await;
3661                    assert_eq!(connections[0].id, common.connection_pool[0].id);
3662                });
3663            }
3664
3665            #[test]
3666            fn single_mode_with_url_path_does_not_force_first_connection() {
3667                TOKIO_SHARED_RT.block_on(async {
3668                    let conn1 = WebsocketConnection::new("c1");
3669                    let conn2 = WebsocketConnection::new("c2");
3670
3671                    let (tx2, _rx2) = unbounded_channel();
3672                    {
3673                        let mut s2 = conn2.state.lock().await;
3674                        s2.ws_write_tx = Some(tx2);
3675                        s2.url_path = Some("path1".to_string());
3676                    }
3677
3678                    {
3679                        let mut s1 = conn1.state.lock().await;
3680                        s1.url_path = Some("path1".to_string());
3681                    }
3682
3683                    let pool = vec![conn1.clone(), conn2.clone()];
3684                    let common = WebsocketCommon::new(pool, WebsocketMode::Single, 0, None, None);
3685
3686                    let connections = common.get_available_connections(false, Some("path1")).await;
3687
3688                    assert_eq!(connections.len(), 1);
3689                    assert_eq!(connections[0].id, "c2");
3690                });
3691            }
3692
3693            #[test]
3694            fn pool_mode_not_ready() {
3695                TOKIO_SHARED_RT.block_on(async {
3696                    let common =
3697                        WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None, None);
3698                    let connections = common.get_available_connections(false, None).await;
3699                    assert!(connections.is_empty());
3700                });
3701            }
3702
3703            #[test]
3704            fn pool_mode_with_ready() {
3705                TOKIO_SHARED_RT.block_on(async {
3706                    let conn1 = WebsocketConnection::new("c1");
3707                    let conn2 = WebsocketConnection::new("c2");
3708                    let (tx1, _rx1) = unbounded_channel();
3709                    {
3710                        let mut s1 = conn1.state.lock().await;
3711                        s1.ws_write_tx = Some(tx1);
3712                    }
3713                    let pool = vec![conn1.clone(), conn2.clone()];
3714                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3715                    let connections = common.get_available_connections(false, None).await;
3716                    assert!(connections.len() == 1);
3717                });
3718            }
3719        }
3720
3721        mod get_connection {
3722            use super::*;
3723
3724            #[test]
3725            fn single_mode() {
3726                TOKIO_SHARED_RT.block_on(async {
3727                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
3728                    let conn = common
3729                        .get_connection(false, None)
3730                        .await
3731                        .expect("should get connection");
3732                    assert_eq!(conn.id, common.connection_pool[0].id);
3733                });
3734            }
3735
3736            #[test]
3737            fn single_mode_with_url_path_selects_matching_ready_connection() {
3738                TOKIO_SHARED_RT.block_on(async {
3739                    let conn1 = WebsocketConnection::new("c1");
3740                    let conn2 = WebsocketConnection::new("c2");
3741
3742                    let (tx1, _rx1) = unbounded_channel();
3743                    {
3744                        let mut s1 = conn1.state.lock().await;
3745                        s1.ws_write_tx = Some(tx1);
3746                        s1.url_path = Some("path2".to_string());
3747                    }
3748
3749                    let (tx2, _rx2) = unbounded_channel();
3750                    {
3751                        let mut s2 = conn2.state.lock().await;
3752                        s2.ws_write_tx = Some(tx2);
3753                        s2.url_path = Some("path1".to_string());
3754                    }
3755
3756                    let pool = vec![conn1.clone(), conn2.clone()];
3757                    let common = WebsocketCommon::new(pool, WebsocketMode::Single, 0, None, None);
3758
3759                    let chosen = common
3760                        .get_connection(false, Some("path1"))
3761                        .await
3762                        .expect("should get connection");
3763
3764                    assert_eq!(chosen.id, "c2");
3765                });
3766            }
3767
3768            #[test]
3769            fn pool_mode_not_ready() {
3770                TOKIO_SHARED_RT.block_on(async {
3771                    let common =
3772                        WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None, None);
3773                    let result = common.get_connection(false, None).await;
3774                    assert!(matches!(
3775                        result,
3776                        Err(crate::errors::WebsocketError::NotConnected)
3777                    ));
3778                });
3779            }
3780
3781            #[test]
3782            fn pool_mode_with_ready() {
3783                TOKIO_SHARED_RT.block_on(async {
3784                    let conn1 = WebsocketConnection::new("c1");
3785                    let conn2 = WebsocketConnection::new("c2");
3786                    let (tx1, _rx1) = unbounded_channel();
3787                    {
3788                        let mut s1 = conn1.state.lock().await;
3789                        s1.ws_write_tx = Some(tx1);
3790                    }
3791                    let pool = vec![conn1.clone(), conn2.clone()];
3792                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3793                    let result = common.get_connection(false, None).await;
3794                    assert!(result.is_ok());
3795                    let chosen = result.unwrap();
3796                    assert_eq!(chosen.id, conn1.id);
3797                });
3798            }
3799
3800            #[test]
3801            fn pool_mode_with_url_path_filters_connections() {
3802                TOKIO_SHARED_RT.block_on(async {
3803                    let conn1 = WebsocketConnection::new("c1");
3804                    let conn2 = WebsocketConnection::new("c2");
3805
3806                    let (tx1, _rx1) = unbounded_channel();
3807                    {
3808                        let mut s1 = conn1.state.lock().await;
3809                        s1.ws_write_tx = Some(tx1);
3810                        s1.url_path = Some("path1".to_string());
3811                    }
3812
3813                    let (tx2, _rx2) = unbounded_channel();
3814                    {
3815                        let mut s2 = conn2.state.lock().await;
3816                        s2.ws_write_tx = Some(tx2);
3817                        s2.url_path = Some("path2".to_string());
3818                    }
3819
3820                    let pool = vec![conn1.clone(), conn2.clone()];
3821                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
3822
3823                    let chosen = common
3824                        .get_connection(false, Some("path2"))
3825                        .await
3826                        .expect("should pick ready connection for path2");
3827
3828                    assert_eq!(chosen.id, "c2");
3829                });
3830            }
3831
3832            #[test]
3833            fn url_path_no_match_returns_not_connected() {
3834                TOKIO_SHARED_RT.block_on(async {
3835                    let conn1 = WebsocketConnection::new("c1");
3836                    let (tx1, _rx1) = unbounded_channel();
3837                    {
3838                        let mut s1 = conn1.state.lock().await;
3839                        s1.ws_write_tx = Some(tx1);
3840                        s1.url_path = Some("path1".to_string());
3841                    }
3842
3843                    let pool = vec![conn1.clone()];
3844                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(1), 0, None, None);
3845
3846                    let result = common.get_connection(false, Some("path2")).await;
3847                    assert!(matches!(
3848                        result,
3849                        Err(crate::errors::WebsocketError::NotConnected)
3850                    ));
3851                });
3852            }
3853        }
3854
3855        mod close_connection_gracefully {
3856            use super::*;
3857
3858            #[tokio::test]
3859            async fn waits_for_pending_requests_then_closes() {
3860                pause();
3861
3862                let conn = WebsocketConnection::new("c1");
3863                let (tx, mut rx) = unbounded_channel::<Message>();
3864                let (req_tx, _req_rx) = oneshot::channel();
3865                {
3866                    let mut st = conn.state.lock().await;
3867                    st.pending_requests
3868                        .insert("r".to_string(), PendingRequest { completion: req_tx });
3869                }
3870                let common =
3871                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3872                let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
3873                advance(Duration::from_secs(1)).await;
3874                {
3875                    let mut st = conn.state.lock().await;
3876                    st.pending_requests.clear();
3877                }
3878                conn.drain_notify.notify_waiters();
3879                advance(Duration::from_secs(1)).await;
3880                close_fut.await.unwrap();
3881                match rx.try_recv() {
3882                    Ok(Message::Close(_)) => {}
3883                    other => panic!("expected Close, got {other:?}"),
3884                }
3885
3886                resume();
3887            }
3888
3889            #[tokio::test]
3890            async fn force_closes_after_timeout() {
3891                pause();
3892
3893                let conn = WebsocketConnection::new("c2");
3894                let (tx, mut rx) = unbounded_channel::<Message>();
3895                let (req_tx, _req_rx) = oneshot::channel();
3896                {
3897                    let mut st = conn.state.lock().await;
3898                    st.pending_requests.insert(
3899                        "request_id".to_string(),
3900                        PendingRequest { completion: req_tx },
3901                    );
3902                }
3903                let common =
3904                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
3905                let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
3906                advance(Duration::from_secs(30)).await;
3907                close_fut.await.unwrap();
3908                match rx.try_recv() {
3909                    Ok(Message::Close(_)) => {}
3910                    other => panic!("expected Close on timeout, got {other:?}"),
3911                }
3912
3913                resume();
3914            }
3915        }
3916
3917        mod get_reconnect_url {
3918            use super::*;
3919
3920            struct DummyHandler {
3921                url: String,
3922            }
3923
3924            #[async_trait::async_trait]
3925            impl WebsocketHandler for DummyHandler {
3926                async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
3927                async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
3928                async fn get_reconnect_url(
3929                    &self,
3930                    _default_url: String,
3931                    _connection: Arc<WebsocketConnection>,
3932                ) -> String {
3933                    self.url.clone()
3934                }
3935            }
3936
3937            #[test]
3938            fn returns_default_when_no_handler() {
3939                TOKIO_SHARED_RT.block_on(async {
3940                    let conn = WebsocketConnection::new("c1");
3941                    let common = WebsocketCommon::new(
3942                        vec![conn.clone()],
3943                        WebsocketMode::Single,
3944                        0,
3945                        None,
3946                        None,
3947                    );
3948                    let default = "wss://default".to_string();
3949                    let result = common.get_reconnect_url(&default, conn.clone()).await;
3950                    assert_eq!(result, default);
3951                });
3952            }
3953
3954            #[test]
3955            fn returns_handler_url_when_set() {
3956                TOKIO_SHARED_RT.block_on(async {
3957                    let conn = WebsocketConnection::new("c2");
3958                    let handler = Arc::new(DummyHandler {
3959                        url: "wss://custom".into(),
3960                    });
3961                    conn.set_handler(handler).await;
3962                    let common = WebsocketCommon::new(
3963                        vec![conn.clone()],
3964                        WebsocketMode::Single,
3965                        0,
3966                        None,
3967                        None,
3968                    );
3969                    let default = "wss://default".to_string();
3970                    let result = common.get_reconnect_url(&default, conn.clone()).await;
3971                    assert_eq!(result, "wss://custom");
3972                });
3973            }
3974        }
3975
3976        mod on_open {
3977            use super::*;
3978
3979            struct DummyHandler {
3980                called: Arc<Mutex<bool>>,
3981                opened_url: Arc<Mutex<Option<String>>>,
3982            }
3983
3984            #[async_trait]
3985            impl WebsocketHandler for DummyHandler {
3986                async fn on_open(&self, url: String, _connection: Arc<WebsocketConnection>) {
3987                    let mut flag = self.called.lock().await;
3988                    *flag = true;
3989                    let mut store = self.opened_url.lock().await;
3990                    *store = Some(url);
3991                }
3992                async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
3993                async fn get_reconnect_url(
3994                    &self,
3995                    default_url: String,
3996                    _connection: Arc<WebsocketConnection>,
3997                ) -> String {
3998                    default_url
3999                }
4000            }
4001
4002            #[test]
4003            fn emits_open_and_calls_handler() {
4004                TOKIO_SHARED_RT.block_on(async {
4005                    let conn = WebsocketConnection::new("c1");
4006                    let called = Arc::new(Mutex::new(false));
4007                    let opened_url = Arc::new(Mutex::new(None));
4008                    let handler = Arc::new(DummyHandler {
4009                        called: called.clone(),
4010                        opened_url: opened_url.clone(),
4011                    });
4012
4013                    conn.set_handler(handler.clone()).await;
4014                    let common = WebsocketCommon::new(
4015                        vec![conn.clone()],
4016                        WebsocketMode::Single,
4017                        0,
4018                        None,
4019                        None,
4020                    );
4021                    let events = subscribe_events(&common);
4022                    common
4023                        .on_open("wss://example.com".into(), conn.clone(), None)
4024                        .await;
4025
4026                    sleep(std::time::Duration::from_millis(10)).await;
4027
4028                    let evs = events.lock().await;
4029                    assert!(evs.iter().any(|e| matches!(e, WebsocketEvent::Open)));
4030                    assert!(*called.lock().await);
4031                    assert_eq!(
4032                        opened_url.lock().await.as_deref(),
4033                        Some("wss://example.com")
4034                    );
4035                });
4036            }
4037
4038            #[test]
4039            fn handles_renewal_pending_and_closes_old_writer() {
4040                TOKIO_SHARED_RT.block_on(async {
4041                    let conn = WebsocketConnection::new("c2");
4042                    let (old_tx, mut old_rx) = unbounded_channel::<Message>();
4043                    {
4044                        let mut st = conn.state.lock().await;
4045                        st.renewal_pending = true;
4046                    }
4047                    let common = WebsocketCommon::new(
4048                        vec![conn.clone()],
4049                        WebsocketMode::Single,
4050                        0,
4051                        None,
4052                        None,
4053                    );
4054                    common
4055                        .on_open("url".into(), conn.clone(), Some(old_tx.clone()))
4056                        .await;
4057                    assert!(!conn.state.lock().await.renewal_pending);
4058                    match old_rx.try_recv() {
4059                        Ok(Message::Close(_)) => {}
4060                        other => panic!("expected Close, got {other:?}"),
4061                    }
4062                });
4063            }
4064        }
4065
4066        mod on_message {
4067            use super::*;
4068
4069            struct DummyHandler {
4070                called_with: Arc<Mutex<Vec<String>>>,
4071            }
4072
4073            #[async_trait]
4074            impl WebsocketHandler for DummyHandler {
4075                async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
4076                async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
4077                    self.called_with.lock().await.push(data);
4078                }
4079                async fn get_reconnect_url(
4080                    &self,
4081                    default_url: String,
4082                    _connection: Arc<WebsocketConnection>,
4083                ) -> String {
4084                    default_url
4085                }
4086            }
4087
4088            #[test]
4089            fn emits_message_event_without_handler() {
4090                TOKIO_SHARED_RT.block_on(async {
4091                    let conn = WebsocketConnection::new("c1");
4092                    let common = WebsocketCommon::new(
4093                        vec![conn.clone()],
4094                        WebsocketMode::Single,
4095                        0,
4096                        None,
4097                        None,
4098                    );
4099                    let events = subscribe_events(&common);
4100                    common.on_message("msg".into(), conn.clone()).await;
4101
4102                    sleep(Duration::from_millis(10)).await;
4103
4104                    let locked = events.lock().await;
4105                    assert!(
4106                        locked
4107                            .iter()
4108                            .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
4109                    );
4110                });
4111            }
4112
4113            #[test]
4114            fn calls_handler_and_emits_message() {
4115                TOKIO_SHARED_RT.block_on(async {
4116                    let conn = WebsocketConnection::new("c2");
4117                    let called = Arc::new(Mutex::new(Vec::new()));
4118                    let handler = Arc::new(DummyHandler {
4119                        called_with: called.clone(),
4120                    });
4121                    conn.set_handler(handler.clone()).await;
4122
4123                    let common = WebsocketCommon::new(
4124                        vec![conn.clone()],
4125                        WebsocketMode::Single,
4126                        0,
4127                        None,
4128                        None,
4129                    );
4130                    let events = subscribe_events(&common);
4131                    common.on_message("msg".into(), conn.clone()).await;
4132
4133                    sleep(Duration::from_millis(10)).await;
4134
4135                    let evs = events.lock().await;
4136                    assert!(
4137                        evs.iter()
4138                            .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
4139                    );
4140                    let msgs = called.lock().await;
4141                    assert_eq!(msgs.as_slice(), &["msg".to_string()]);
4142                });
4143            }
4144        }
4145
4146        mod create_websocket {
4147            use super::*;
4148
4149            #[test]
4150            fn successful_connection() {
4151                TOKIO_SHARED_RT.block_on(async {
4152                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4153                    let addr: SocketAddr = listener.local_addr().unwrap();
4154
4155                    let expected_ua = build_user_agent("product");
4156                    let expected_ua_clone = expected_ua.clone();
4157
4158                    tokio::spawn(async move {
4159                        if let Ok((stream, _)) = listener.accept().await {
4160                            let callback = |req: &Request, resp| {
4161                                let got = req
4162                                    .headers()
4163                                    .get(USER_AGENT)
4164                                    .expect("no USER_AGENT header in WS handshake")
4165                                    .to_str()
4166                                    .expect("invalid USER_AGENT header");
4167                                assert_eq!(got, expected_ua_clone, "User-Agent mismatch");
4168                                Ok(resp)
4169                            };
4170                            let _ = accept_hdr_async(stream, callback).await.unwrap();
4171                        }
4172                    });
4173
4174                    let url = format!("ws://{addr}");
4175                    let res =
4176                        WebsocketCommon::create_websocket(&url, None, Some(expected_ua)).await;
4177                    assert!(res.is_ok(), "handshake failed: {res:?}");
4178                });
4179            }
4180
4181            #[test]
4182            fn invalid_url_returns_handshake_error() {
4183                TOKIO_SHARED_RT.block_on(async {
4184                    let res =
4185                        WebsocketCommon::create_websocket("not-a-valid-url", None, None).await;
4186                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4187                });
4188            }
4189
4190            #[test]
4191            fn unreachable_host_returns_handshake_error() {
4192                TOKIO_SHARED_RT.block_on(async {
4193                    let res =
4194                        WebsocketCommon::create_websocket("ws://127.0.0.1:1", None, None).await;
4195                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4196                });
4197            }
4198        }
4199
4200        mod connect_pool {
4201            use super::*;
4202
4203            #[test]
4204            fn connects_all_in_pool() {
4205                TOKIO_SHARED_RT.block_on(async {
4206                    let pool_size = 3;
4207                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4208                    let addr = listener.local_addr().unwrap();
4209                    tokio::spawn(async move {
4210                        for _ in 0..pool_size {
4211                            if let Ok((stream, _)) = listener.accept().await {
4212                                tokio::spawn(async move {
4213                                    let mut ws = accept_async(stream).await.unwrap();
4214                                    sleep(Duration::from_millis(500)).await;
4215                                    let _ = ws.close(None).await;
4216                                });
4217                            }
4218                        }
4219                    });
4220                    let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
4221                        .map(|i| WebsocketConnection::new(format!("c{i}")))
4222                        .collect();
4223                    let common = WebsocketCommon::new(
4224                        conns.clone(),
4225                        WebsocketMode::Pool(pool_size),
4226                        0,
4227                        None,
4228                        None,
4229                    );
4230                    let url = format!("ws://{addr}");
4231                    common.clone().connect_pool(&url, None).await.unwrap();
4232                    for conn in conns {
4233                        let mut ok = false;
4234                        for _ in 0..100 {
4235                            if conn.state.lock().await.ws_write_tx.is_some() {
4236                                ok = true;
4237                                break;
4238                            }
4239                            sleep(Duration::from_millis(50)).await;
4240                        }
4241                        assert!(ok, "expected ws_write_tx Some after connect");
4242                    }
4243                });
4244            }
4245
4246            #[test]
4247            fn fails_if_any_refused() {
4248                TOKIO_SHARED_RT.block_on(async {
4249                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4250                    let addr = listener.local_addr().unwrap();
4251                    let pool_size = 3;
4252                    tokio::spawn(async move {
4253                        for _ in 0..2 {
4254                            if let Ok((stream, _)) = listener.accept().await {
4255                                let mut ws = accept_async(stream).await.unwrap();
4256                                let _ = ws.close(None).await;
4257                            }
4258                        }
4259                    });
4260                    let mut conns = Vec::new();
4261                    let valid_url = format!("ws://{addr}");
4262                    for i in 0..2 {
4263                        conns.push(WebsocketConnection::new(format!("c{i}")));
4264                    }
4265                    conns.push(WebsocketConnection::new("bad"));
4266                    let common = WebsocketCommon::new(
4267                        conns.clone(),
4268                        WebsocketMode::Pool(pool_size),
4269                        0,
4270                        None,
4271                        None,
4272                    );
4273                    let res = common.clone().connect_pool(&valid_url, None).await;
4274                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4275                });
4276            }
4277
4278            #[test]
4279            fn fails_on_invalid_url() {
4280                TOKIO_SHARED_RT.block_on(async {
4281                    let conns = vec![WebsocketConnection::new("c1")];
4282                    let common = WebsocketCommon::new(conns, WebsocketMode::Pool(1), 0, None, None);
4283                    let res = common.connect_pool("not-a-url", None).await;
4284                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4285                });
4286            }
4287
4288            #[test]
4289            fn fails_if_mixed_success_and_invalid_url() {
4290                TOKIO_SHARED_RT.block_on(async {
4291                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4292                    let addr = listener.local_addr().unwrap();
4293                    tokio::spawn(async move {
4294                        if let Ok((stream, _)) = listener.accept().await {
4295                            let mut ws = accept_async(stream).await.unwrap();
4296                            let _ = ws.close(None).await;
4297                        }
4298                    });
4299                    let good = WebsocketConnection::new("good");
4300                    let bad = WebsocketConnection::new("bad");
4301                    let common = WebsocketCommon::new(
4302                        vec![good, bad],
4303                        WebsocketMode::Pool(2),
4304                        0,
4305                        None,
4306                        None,
4307                    );
4308                    let url = format!("ws://{addr}");
4309                    let res = common.connect_pool(&url, None).await;
4310                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4311                });
4312            }
4313
4314            #[test]
4315            fn init_connect_invoked_for_each() {
4316                TOKIO_SHARED_RT.block_on(async {
4317                    let pool_size = 2;
4318                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4319                    let addr = listener.local_addr().unwrap();
4320                    tokio::spawn(async move {
4321                        for _ in 0..pool_size {
4322                            if let Ok((stream, _)) = listener.accept().await {
4323                                tokio::spawn(async move {
4324                                    let mut ws = accept_async(stream).await.unwrap();
4325                                    sleep(Duration::from_millis(500)).await;
4326                                    let _ = ws.close(None).await;
4327                                });
4328                            }
4329                        }
4330                    });
4331                    let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
4332                        .map(|i| WebsocketConnection::new(format!("c{i}")))
4333                        .collect();
4334                    let common = WebsocketCommon::new(
4335                        conns.clone(),
4336                        WebsocketMode::Pool(pool_size),
4337                        0,
4338                        None,
4339                        None,
4340                    );
4341                    let url = format!("ws://{addr}");
4342                    common.clone().connect_pool(&url, None).await.unwrap();
4343                    for conn in conns {
4344                        let mut ok = false;
4345                        for _ in 0..100 {
4346                            if conn.state.lock().await.ws_write_tx.is_some() {
4347                                ok = true;
4348                                break;
4349                            }
4350                            sleep(Duration::from_millis(25)).await;
4351                        }
4352                        assert!(ok, "expected ws_write_tx Some for {}", conn.id);
4353                    }
4354                });
4355            }
4356
4357            #[test]
4358            fn single_mode_uses_first_connection() {
4359                TOKIO_SHARED_RT.block_on(async {
4360                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4361                    let addr = listener.local_addr().unwrap();
4362                    tokio::spawn(async move {
4363                        if let Ok((stream, _)) = listener.accept().await {
4364                            let mut ws = accept_async(stream).await.unwrap();
4365                            let _ = ws.close(None).await;
4366                        }
4367                    });
4368                    let conn = WebsocketConnection::new("c1");
4369                    let common = WebsocketCommon::new(
4370                        vec![conn.clone()],
4371                        WebsocketMode::Single,
4372                        0,
4373                        None,
4374                        None,
4375                    );
4376                    let url = format!("ws://{addr}");
4377                    common.connect_pool(&url, None).await.unwrap();
4378                    let ok = eventually_async(Duration::from_secs(5), || {
4379                        let conn = conn.clone();
4380                        async move { conn.state.lock().await.ws_write_tx.is_some() }
4381                    })
4382                    .await;
4383
4384                    assert!(ok, "single mode did not select first connection");
4385                });
4386            }
4387
4388            #[test]
4389            fn empty_subset_is_ok_and_connects_none() {
4390                TOKIO_SHARED_RT.block_on(async {
4391                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4392                    let addr = listener.local_addr().unwrap();
4393
4394                    tokio::spawn(async move {
4395                        let _ = addr;
4396                    });
4397
4398                    let c1 = WebsocketConnection::new("c1");
4399                    let c2 = WebsocketConnection::new("c2");
4400
4401                    let common = WebsocketCommon::new(
4402                        vec![c1.clone(), c2.clone()],
4403                        WebsocketMode::Pool(2),
4404                        0,
4405                        None,
4406                        None,
4407                    );
4408
4409                    let url = format!("ws://{addr}");
4410                    common
4411                        .clone()
4412                        .connect_pool(&url, Some(vec![]))
4413                        .await
4414                        .unwrap();
4415
4416                    assert!(c1.state.lock().await.ws_write_tx.is_none());
4417                    assert!(c2.state.lock().await.ws_write_tx.is_none());
4418                });
4419            }
4420        }
4421
4422        mod init_connect {
4423            use super::*;
4424
4425            #[test]
4426            fn pool_mode_none_connection_uses_first() {
4427                TOKIO_SHARED_RT.block_on(async {
4428                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4429                    let addr = listener.local_addr().unwrap();
4430                    tokio::spawn(async move {
4431                        for _ in 0..2 {
4432                            if let Ok((stream, _)) = listener.accept().await {
4433                                let mut ws = accept_async(stream).await.unwrap();
4434                                ws.close(None).await.ok();
4435                            }
4436                        }
4437                    });
4438
4439                    let c1 = WebsocketConnection::new("c1");
4440                    let c2 = WebsocketConnection::new("c2");
4441                    let common = WebsocketCommon::new(
4442                        vec![c1.clone(), c2.clone()],
4443                        WebsocketMode::Pool(2),
4444                        0,
4445                        None,
4446                        None,
4447                    );
4448                    let url = format!("ws://{addr}");
4449
4450                    common
4451                        .clone()
4452                        .init_connect(&url, false, None)
4453                        .await
4454                        .unwrap();
4455
4456                    let ok = eventually_async(Duration::from_secs(5), || {
4457                        let conn1 = c1.clone();
4458                        async move { conn1.state.lock().await.ws_write_tx.is_some() }
4459                    })
4460                    .await;
4461
4462                    assert!(ok, "first connection was never selected");
4463                    assert!(c2.state.lock().await.ws_write_tx.is_none());
4464                });
4465            }
4466
4467            #[test]
4468            fn writer_channel_can_send_text() {
4469                TOKIO_SHARED_RT.block_on(async {
4470                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4471                    let addr = listener.local_addr().unwrap();
4472                    let received = Arc::new(Mutex::new(None::<String>));
4473                    let received_clone = received.clone();
4474
4475                    tokio::spawn(async move {
4476                        if let Ok((stream, _)) = listener.accept().await {
4477                            let mut ws = accept_async(stream).await.unwrap();
4478                            if let Some(Ok(Message::Text(txt))) = ws.next().await {
4479                                *received_clone.lock().await = Some(txt.to_string());
4480                            }
4481                            ws.close(None).await.ok();
4482                        }
4483                    });
4484
4485                    let conn = WebsocketConnection::new("cw");
4486                    let common = WebsocketCommon::new(
4487                        vec![conn.clone()],
4488                        WebsocketMode::Single,
4489                        0,
4490                        None,
4491                        None,
4492                    );
4493                    let url = format!("ws://{addr}");
4494                    common
4495                        .clone()
4496                        .init_connect(&url, false, Some(conn.clone()))
4497                        .await
4498                        .unwrap();
4499
4500                    let tx = conn.state.lock().await.ws_write_tx.clone().unwrap();
4501                    tx.send(Message::Text("ping".into())).unwrap();
4502
4503                    sleep(Duration::from_millis(50)).await;
4504
4505                    let lock = received.lock().await;
4506                    assert_eq!(lock.as_deref(), Some("ping"));
4507                });
4508            }
4509
4510            #[test]
4511            fn does_not_skip_when_reconnection_pending_even_if_writer_exists() {
4512                TOKIO_SHARED_RT.block_on(async {
4513                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4514                    let addr = listener.local_addr().unwrap();
4515
4516                    tokio::spawn(async move {
4517                        if let Ok((stream, _)) = listener.accept().await {
4518                            let mut ws = accept_async(stream).await.unwrap();
4519                            let _ = ws.close(None).await;
4520                        }
4521                    });
4522
4523                    let conn = WebsocketConnection::new("c-reconnect");
4524                    {
4525                        let mut st = conn.state.lock().await;
4526                        let (tx, _) = unbounded_channel::<Message>();
4527                        st.ws_write_tx = Some(tx);
4528                        st.reconnection_pending = true;
4529                    }
4530
4531                    let common = WebsocketCommon::new(
4532                        vec![conn.clone()],
4533                        WebsocketMode::Single,
4534                        0,
4535                        None,
4536                        None,
4537                    );
4538
4539                    let url = format!("ws://{addr}");
4540                    common
4541                        .clone()
4542                        .init_connect(&url, false, Some(conn.clone()))
4543                        .await
4544                        .unwrap();
4545
4546                    let st = conn.state.lock().await;
4547                    assert!(
4548                        st.ws_write_tx.is_some(),
4549                        "writer should be set after connect"
4550                    );
4551                    assert!(
4552                        !st.reconnection_pending,
4553                        "reconnection_pending should be cleared after successful connect"
4554                    );
4555                });
4556            }
4557
4558            #[test]
4559            fn responds_to_ping_with_pong() {
4560                TOKIO_SHARED_RT.block_on(async {
4561                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4562                    let addr = listener.local_addr().unwrap();
4563
4564                    let saw_pong = Arc::new(Mutex::new(false));
4565                    let saw_pong2 = saw_pong.clone();
4566
4567                    tokio::spawn(async move {
4568                        if let Ok((stream, _)) = listener.accept().await {
4569                            let mut ws = accept_async(stream).await.unwrap();
4570                            ws.send(Message::Ping(vec![1, 2, 3].into())).await.unwrap();
4571                            if let Some(Ok(Message::Pong(payload))) = ws.next().await {
4572                                if payload[..] == [1, 2, 3] {
4573                                    *saw_pong2.lock().await = true;
4574                                }
4575                            }
4576                            let _ = ws.close(None).await;
4577                        }
4578                    });
4579
4580                    let conn = WebsocketConnection::new("c-ping");
4581                    let common = WebsocketCommon::new(
4582                        vec![conn.clone()],
4583                        WebsocketMode::Single,
4584                        0,
4585                        None,
4586                        None,
4587                    );
4588                    let url = format!("ws://{addr}");
4589                    common
4590                        .clone()
4591                        .init_connect(&url, false, Some(conn))
4592                        .await
4593                        .unwrap();
4594
4595                    sleep(Duration::from_millis(50)).await;
4596
4597                    assert!(*saw_pong.lock().await, "server should have seen a Pong");
4598                });
4599            }
4600
4601            #[test]
4602            fn handshake_error_on_invalid_url() {
4603                TOKIO_SHARED_RT.block_on(async {
4604                    let conn = WebsocketConnection::new("c-invalid");
4605                    let common = WebsocketCommon::new(
4606                        vec![conn.clone()],
4607                        WebsocketMode::Single,
4608                        0,
4609                        None,
4610                        None,
4611                    );
4612                    let res = common
4613                        .clone()
4614                        .init_connect("not-a-url", false, Some(conn.clone()))
4615                        .await;
4616                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
4617                });
4618            }
4619
4620            #[test]
4621            fn skip_if_writer_exists_and_not_renewal() {
4622                TOKIO_SHARED_RT.block_on(async {
4623                    let conn = WebsocketConnection::new("c-writer");
4624                    let (tx, mut rx) = unbounded_channel::<Message>();
4625                    {
4626                        let mut st = conn.state.lock().await;
4627                        st.ws_write_tx = Some(tx.clone());
4628                    }
4629                    let common = WebsocketCommon::new(
4630                        vec![conn.clone()],
4631                        WebsocketMode::Single,
4632                        0,
4633                        None,
4634                        None,
4635                    );
4636                    let res = common
4637                        .clone()
4638                        .init_connect("ws://127.0.0.1:1", false, Some(conn.clone()))
4639                        .await;
4640
4641                    assert!(res.is_ok());
4642                    assert!(rx.try_recv().is_err());
4643                });
4644            }
4645
4646            #[test]
4647            fn short_circuit_on_already_renewing() {
4648                TOKIO_SHARED_RT.block_on(async {
4649                    let conn = WebsocketConnection::new("c-renew");
4650                    {
4651                        let mut st = conn.state.lock().await;
4652                        st.renewal_pending = true;
4653                    }
4654                    let common = WebsocketCommon::new(
4655                        vec![conn.clone()],
4656                        WebsocketMode::Single,
4657                        0,
4658                        None,
4659                        None,
4660                    );
4661                    let res = common
4662                        .clone()
4663                        .init_connect("ws://127.0.0.1:1", true, Some(conn.clone()))
4664                        .await;
4665
4666                    assert!(res.is_ok());
4667                    assert!(conn.state.lock().await.ws_write_tx.is_none());
4668                });
4669            }
4670
4671            #[test]
4672            fn is_renewal_true_sets_and_clears_flag() {
4673                TOKIO_SHARED_RT.block_on(async {
4674                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4675                    let addr = listener.local_addr().unwrap();
4676                    tokio::spawn(async move {
4677                        if let Ok((stream, _)) = listener.accept().await {
4678                            let mut ws = accept_async(stream).await.unwrap();
4679                            let _ = ws.close(None).await;
4680                        }
4681                    });
4682
4683                    let conn = WebsocketConnection::new("c-new-renew");
4684                    let common = WebsocketCommon::new(
4685                        vec![conn.clone()],
4686                        WebsocketMode::Single,
4687                        0,
4688                        None,
4689                        None,
4690                    );
4691                    let url = format!("ws://{addr}");
4692                    let res = common
4693                        .clone()
4694                        .init_connect(&url, true, Some(conn.clone()))
4695                        .await;
4696
4697                    assert!(res.is_ok());
4698
4699                    {
4700                        let st = conn.state.lock().await;
4701                        assert!(st.ws_write_tx.is_some(), "writer should be set");
4702                        assert!(
4703                            st.renewal_pending,
4704                            "renewal_pending must be true until on_open"
4705                        );
4706                    }
4707
4708                    common.on_open(url.clone(), conn.clone(), None).await;
4709
4710                    let st = conn.state.lock().await;
4711                    assert!(
4712                        !st.renewal_pending,
4713                        "renewal_pending should be cleared in on_open"
4714                    );
4715                });
4716            }
4717
4718            #[test]
4719            fn default_connection_selected_when_none_passed() {
4720                TOKIO_SHARED_RT.block_on(async {
4721                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4722                    let addr = listener.local_addr().unwrap();
4723                    tokio::spawn(async move {
4724                        if let Ok((stream, _)) = listener.accept().await {
4725                            let mut ws = accept_async(stream).await.unwrap();
4726                            let _ = ws.close(None).await;
4727                        }
4728                    });
4729                    let conn = WebsocketConnection::new("c-default");
4730                    let common = WebsocketCommon::new(
4731                        vec![conn.clone()],
4732                        WebsocketMode::Single,
4733                        0,
4734                        None,
4735                        None,
4736                    );
4737                    let url = format!("ws://{addr}");
4738                    let res = common.clone().init_connect(&url, false, None).await;
4739
4740                    assert!(res.is_ok());
4741                    let ok = eventually_async(Duration::from_secs(5), || {
4742                        let conn = conn.clone();
4743                        async move { conn.state.lock().await.ws_write_tx.is_some() }
4744                    })
4745                    .await;
4746
4747                    assert!(ok, "default connection was never selected");
4748                });
4749            }
4750
4751            #[test]
4752            fn schedules_reconnect_on_abnormal_close() {
4753                TOKIO_SHARED_RT.block_on(async {
4754                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4755                    let addr = listener.local_addr().unwrap();
4756                    tokio::spawn(async move {
4757                        if let Ok((stream, _)) = listener.accept().await {
4758                            let mut ws = accept_async(stream).await.unwrap();
4759                            ws.close(Some(tungstenite::protocol::CloseFrame {
4760                                code: tungstenite::protocol::frame::coding::CloseCode::Abnormal,
4761                                reason: "oops".into(),
4762                            }))
4763                            .await
4764                            .ok();
4765                        }
4766                    });
4767                    let conn = WebsocketConnection::new("c-close");
4768                    let common = WebsocketCommon::new(
4769                        vec![conn.clone()],
4770                        WebsocketMode::Single,
4771                        5_000,
4772                        None,
4773                        None,
4774                    );
4775                    let url = format!("ws://{addr}");
4776                    common
4777                        .clone()
4778                        .init_connect(&url, false, Some(conn.clone()))
4779                        .await
4780                        .unwrap();
4781
4782                    sleep(Duration::from_millis(50)).await;
4783
4784                    let st = conn.state.lock().await;
4785                    assert!(
4786                        st.reconnection_pending,
4787                        "expected reconnection_pending to be true after abnormal close"
4788                    );
4789                    assert!(
4790                        st.ws_write_tx.is_none(),
4791                        "ws_write_tx should be cleared when scheduling a reconnect"
4792                    );
4793                });
4794            }
4795        }
4796
4797        mod disconnect {
4798            use super::*;
4799
4800            #[test]
4801            fn returns_ok_when_no_connections_are_ready() {
4802                TOKIO_SHARED_RT.block_on(async {
4803                    let conn = WebsocketConnection::new("c1");
4804                    let common = WebsocketCommon::new(
4805                        vec![conn.clone()],
4806                        WebsocketMode::Single,
4807                        0,
4808                        None,
4809                        None,
4810                    );
4811                    let res = common.disconnect().await;
4812
4813                    assert!(res.is_ok());
4814                    assert!(!conn.state.lock().await.close_initiated);
4815                });
4816            }
4817
4818            #[test]
4819            fn closes_all_ready_connections() {
4820                TOKIO_SHARED_RT.block_on(async {
4821                    let conn1 = WebsocketConnection::new("c1");
4822                    let conn2 = WebsocketConnection::new("c2");
4823                    let (tx1, mut rx1) = unbounded_channel::<Message>();
4824                    let (tx2, mut rx2) = unbounded_channel::<Message>();
4825                    {
4826                        let mut s1 = conn1.state.lock().await;
4827                        s1.ws_write_tx = Some(tx1);
4828                    }
4829                    {
4830                        let mut s2 = conn2.state.lock().await;
4831                        s2.ws_write_tx = Some(tx2);
4832                    }
4833                    let common = WebsocketCommon::new(
4834                        vec![conn1.clone(), conn2.clone()],
4835                        WebsocketMode::Pool(2),
4836                        0,
4837                        None,
4838                        None,
4839                    );
4840                    let fut = common.disconnect();
4841
4842                    sleep(Duration::from_millis(50)).await;
4843
4844                    fut.await.unwrap();
4845
4846                    assert!(conn1.state.lock().await.close_initiated);
4847                    assert!(conn2.state.lock().await.close_initiated);
4848
4849                    {
4850                        let st = conn1.state.lock().await;
4851                        assert!(!st.is_session_logged_on, "conn1 should be logged out");
4852                        assert!(st.session_logon_req.is_none(), "conn1 req cleared");
4853                    }
4854                    {
4855                        let st = conn2.state.lock().await;
4856                        assert!(!st.is_session_logged_on, "conn2 should be logged out");
4857                        assert!(st.session_logon_req.is_none(), "conn2 req cleared");
4858                    }
4859
4860                    match (rx1.try_recv(), rx2.try_recv()) {
4861                        (Ok(Message::Close(_)), Ok(Message::Close(_))) => {}
4862                        other => panic!("expected two Close frames, got {other:?}"),
4863                    }
4864                });
4865            }
4866
4867            #[test]
4868            fn does_not_mark_close_initiated_if_no_writer() {
4869                TOKIO_SHARED_RT.block_on(async {
4870                    let conn = WebsocketConnection::new("c-new");
4871                    let common = WebsocketCommon::new(
4872                        vec![conn.clone()],
4873                        WebsocketMode::Single,
4874                        0,
4875                        None,
4876                        None,
4877                    );
4878                    common.disconnect().await.unwrap();
4879
4880                    assert!(!conn.state.lock().await.close_initiated);
4881                });
4882            }
4883
4884            #[test]
4885            fn mixed_pool_marks_all_and_closes_only_writers() {
4886                TOKIO_SHARED_RT.block_on(async {
4887                    let conn_w = WebsocketConnection::new("with");
4888                    let conn_wo = WebsocketConnection::new("without");
4889                    let (tx, mut rx) = unbounded_channel::<Message>();
4890                    {
4891                        let mut st = conn_w.state.lock().await;
4892                        st.ws_write_tx = Some(tx);
4893                    }
4894                    let common = WebsocketCommon::new(
4895                        vec![conn_w.clone(), conn_wo.clone()],
4896                        WebsocketMode::Pool(2),
4897                        0,
4898                        None,
4899                        None,
4900                    );
4901                    let fut = common.disconnect();
4902
4903                    sleep(Duration::from_millis(50)).await;
4904
4905                    fut.await.unwrap();
4906
4907                    assert!(conn_w.state.lock().await.close_initiated);
4908                    assert!(conn_wo.state.lock().await.close_initiated);
4909                    assert!(matches!(rx.try_recv(), Ok(Message::Close(_))));
4910                });
4911            }
4912
4913            #[test]
4914            fn after_disconnect_not_connected() {
4915                TOKIO_SHARED_RT.block_on(async {
4916                    let conn = WebsocketConnection::new("c1");
4917                    let (tx, mut _rx) = unbounded_channel::<Message>();
4918                    {
4919                        let mut st = conn.state.lock().await;
4920                        st.ws_write_tx = Some(tx);
4921                    }
4922                    let common = WebsocketCommon::new(
4923                        vec![conn.clone()],
4924                        WebsocketMode::Single,
4925                        0,
4926                        None,
4927                        None,
4928                    );
4929                    common.disconnect().await.unwrap();
4930                    assert!(!common.is_connected(Some(&conn)).await);
4931                });
4932            }
4933        }
4934
4935        mod ping_server {
4936            use super::*;
4937
4938            #[test]
4939            fn sends_ping_to_all_ready_connections() {
4940                TOKIO_SHARED_RT.block_on(async {
4941                    let mut conns = Vec::new();
4942                    for i in 0..3 {
4943                        let conn = WebsocketConnection::new(format!("c{i}"));
4944                        let (tx, rx) = unbounded_channel::<Message>();
4945                        {
4946                            let mut st = conn.state.lock().await;
4947                            st.ws_write_tx = Some(tx);
4948                        }
4949                        conns.push((conn, rx));
4950                    }
4951                    let common = WebsocketCommon::new(
4952                        conns.iter().map(|(c, _)| c.clone()).collect(),
4953                        WebsocketMode::Pool(3),
4954                        0,
4955                        None,
4956                        None,
4957                    );
4958                    common.ping_server().await;
4959                    for (_, mut rx) in conns {
4960                        match rx.try_recv() {
4961                            Ok(Message::Ping(payload)) if payload.is_empty() => {}
4962                            other => panic!("expected empty-payload Ping, got {other:?}"),
4963                        }
4964                    }
4965                });
4966            }
4967
4968            #[test]
4969            fn skips_not_ready_and_partial() {
4970                TOKIO_SHARED_RT.block_on(async {
4971                    let ready = WebsocketConnection::new("ready");
4972                    let not_ready = WebsocketConnection::new("not-ready");
4973                    let (tx_r, mut rx_r) = unbounded_channel::<Message>();
4974                    {
4975                        let mut st = ready.state.lock().await;
4976                        st.ws_write_tx = Some(tx_r);
4977                    }
4978                    {
4979                        let mut st = not_ready.state.lock().await;
4980                        st.ws_write_tx = None;
4981                    }
4982                    let common = WebsocketCommon::new(
4983                        vec![ready.clone(), not_ready.clone()],
4984                        WebsocketMode::Pool(2),
4985                        0,
4986                        None,
4987                        None,
4988                    );
4989                    common.ping_server().await;
4990                    match rx_r.try_recv() {
4991                        Ok(Message::Ping(payload)) if payload.is_empty() => {}
4992                        other => panic!("expected Ping on ready, got {other:?}"),
4993                    }
4994                });
4995            }
4996
4997            #[test]
4998            fn no_ping_when_flags_block() {
4999                TOKIO_SHARED_RT.block_on(async {
5000                    let conn = WebsocketConnection::new("c1");
5001                    let (tx, mut rx) = unbounded_channel::<Message>();
5002                    {
5003                        let mut st = conn.state.lock().await;
5004                        st.ws_write_tx = Some(tx);
5005                        st.reconnection_pending = true;
5006                    }
5007                    let common = WebsocketCommon::new(
5008                        vec![conn.clone()],
5009                        WebsocketMode::Single,
5010                        0,
5011                        None,
5012                        None,
5013                    );
5014                    common.ping_server().await;
5015                    assert!(rx.try_recv().is_err());
5016                });
5017            }
5018        }
5019
5020        mod send {
5021            use super::*;
5022
5023            #[test]
5024            fn round_robin_send_without_specific() {
5025                TOKIO_SHARED_RT.block_on(async {
5026                    let conn1 = WebsocketConnection::new("c1");
5027                    let conn2 = WebsocketConnection::new("c2");
5028                    let (tx1, mut rx1) = unbounded_channel::<Message>();
5029                    let (tx2, mut rx2) = unbounded_channel::<Message>();
5030                    {
5031                        let mut s1 = conn1.state.lock().await;
5032                        s1.ws_write_tx = Some(tx1);
5033                    }
5034                    {
5035                        let mut s2 = conn2.state.lock().await;
5036                        s2.ws_write_tx = Some(tx2);
5037                    }
5038                    let common = WebsocketCommon::new(
5039                        vec![conn1.clone(), conn2.clone()],
5040                        WebsocketMode::Pool(2),
5041                        0,
5042                        None,
5043                        None,
5044                    );
5045
5046                    let res1 = common
5047                        .send("a".into(), None, false, Duration::from_secs(1), None)
5048                        .await
5049                        .unwrap();
5050                    assert!(res1.is_none());
5051
5052                    let res2 = common
5053                        .send("b".into(), None, false, Duration::from_secs(1), None)
5054                        .await
5055                        .unwrap();
5056                    assert!(res2.is_none());
5057
5058                    assert_eq!(
5059                        if let Message::Text(t) = rx1.try_recv().unwrap() {
5060                            t
5061                        } else {
5062                            panic!()
5063                        },
5064                        "a"
5065                    );
5066                    assert_eq!(
5067                        if let Message::Text(t) = rx2.try_recv().unwrap() {
5068                            t
5069                        } else {
5070                            panic!()
5071                        },
5072                        "b"
5073                    );
5074                });
5075            }
5076
5077            #[test]
5078            fn round_robin_skips_not_ready() {
5079                TOKIO_SHARED_RT.block_on(async {
5080                    let conn1 = WebsocketConnection::new("c1");
5081                    let conn2 = WebsocketConnection::new("c2");
5082                    let (tx2, mut rx2) = unbounded_channel::<Message>();
5083                    {
5084                        let mut s1 = conn1.state.lock().await;
5085                        s1.ws_write_tx = None;
5086                    }
5087                    {
5088                        let mut s2 = conn2.state.lock().await;
5089                        s2.ws_write_tx = Some(tx2);
5090                    }
5091                    let common = WebsocketCommon::new(
5092                        vec![conn1.clone(), conn2.clone()],
5093                        WebsocketMode::Pool(2),
5094                        0,
5095                        None,
5096                        None,
5097                    );
5098                    let res = common
5099                        .send("bar".into(), None, false, Duration::from_secs(1), None)
5100                        .await
5101                        .unwrap();
5102                    assert!(res.is_none());
5103                    match rx2.try_recv().unwrap() {
5104                        Message::Text(t) => assert_eq!(t, "bar"),
5105                        other => panic!("unexpected {other:?}"),
5106                    }
5107                });
5108            }
5109
5110            #[test]
5111            fn sync_send_on_specific_connection() {
5112                TOKIO_SHARED_RT.block_on(async {
5113                    let conn1 = WebsocketConnection::new("c1");
5114                    let conn2 = WebsocketConnection::new("c2");
5115                    let (tx2, mut rx2) = unbounded_channel::<Message>();
5116                    {
5117                        let mut st = conn2.state.lock().await;
5118                        st.ws_write_tx = Some(tx2);
5119                    }
5120                    let common = WebsocketCommon::new(
5121                        vec![conn1.clone(), conn2.clone()],
5122                        WebsocketMode::Pool(2),
5123                        0,
5124                        None,
5125                        None,
5126                    );
5127                    let res = common
5128                        .send(
5129                            "payload".into(),
5130                            Some("id".into()),
5131                            false,
5132                            Duration::from_secs(1),
5133                            Some(conn2.clone()),
5134                        )
5135                        .await
5136                        .unwrap();
5137                    assert!(res.is_none());
5138                    match rx2.try_recv() {
5139                        Ok(Message::Text(t)) => assert_eq!(t, "payload"),
5140                        other => panic!("expected Text, got {other:?}"),
5141                    }
5142                });
5143            }
5144
5145            #[test]
5146            fn sync_send_with_id_does_not_insert_pending() {
5147                TOKIO_SHARED_RT.block_on(async {
5148                    let conn = WebsocketConnection::new("c1");
5149                    let (tx, mut rx) = unbounded_channel::<Message>();
5150                    {
5151                        let mut st = conn.state.lock().await;
5152                        st.ws_write_tx = Some(tx);
5153                    }
5154                    let common = WebsocketCommon::new(
5155                        vec![conn.clone()],
5156                        WebsocketMode::Single,
5157                        0,
5158                        None,
5159                        None,
5160                    );
5161                    let res = common
5162                        .send(
5163                            "msg".into(),
5164                            Some("id".into()),
5165                            false,
5166                            Duration::from_secs(1),
5167                            Some(conn.clone()),
5168                        )
5169                        .await
5170                        .unwrap();
5171                    assert!(res.is_none());
5172                    assert!(conn.state.lock().await.pending_requests.is_empty());
5173                    match rx.try_recv().unwrap() {
5174                        Message::Text(t) => assert_eq!(t, "msg"),
5175                        other => panic!("unexpected {other:?}"),
5176                    }
5177                });
5178            }
5179
5180            #[test]
5181            fn sync_send_error_if_not_ready() {
5182                TOKIO_SHARED_RT.block_on(async {
5183                    let conn = WebsocketConnection::new("c1");
5184                    let common = WebsocketCommon::new(
5185                        vec![conn.clone()],
5186                        WebsocketMode::Single,
5187                        0,
5188                        None,
5189                        None,
5190                    );
5191                    let err = common
5192                        .send(
5193                            "msg".into(),
5194                            Some("id".into()),
5195                            false,
5196                            Duration::from_secs(1),
5197                            Some(conn.clone()),
5198                        )
5199                        .await
5200                        .unwrap_err();
5201                    assert!(matches!(err, WebsocketError::NotConnected));
5202                });
5203            }
5204
5205            #[test]
5206            fn sync_send_error_when_no_ready() {
5207                TOKIO_SHARED_RT.block_on(async {
5208                    let conn = WebsocketConnection::new("c1");
5209                    let common = WebsocketCommon::new(
5210                        vec![conn.clone()],
5211                        WebsocketMode::Single,
5212                        0,
5213                        None,
5214                        None,
5215                    );
5216                    let err = common
5217                        .send("msg".into(), None, false, Duration::from_secs(1), None)
5218                        .await
5219                        .unwrap_err();
5220                    assert!(matches!(err, WebsocketError::NotConnected));
5221                });
5222            }
5223
5224            #[test]
5225            fn async_send_and_receive() {
5226                TOKIO_SHARED_RT.block_on(async {
5227                    let conn = WebsocketConnection::new("c1");
5228                    let (tx, mut rx) = unbounded_channel::<Message>();
5229                    {
5230                        let mut st = conn.state.lock().await;
5231                        st.ws_write_tx = Some(tx);
5232                    }
5233                    let common = WebsocketCommon::new(
5234                        vec![conn.clone()],
5235                        WebsocketMode::Single,
5236                        0,
5237                        None,
5238                        None,
5239                    );
5240                    let fut = common
5241                        .send(
5242                            "hello".into(),
5243                            Some("id".into()),
5244                            true,
5245                            Duration::from_secs(5),
5246                            Some(conn.clone()),
5247                        )
5248                        .await
5249                        .unwrap()
5250                        .unwrap();
5251                    match rx.try_recv() {
5252                        Ok(Message::Text(t)) => assert_eq!(t, "hello"),
5253                        other => panic!("expected Text, got {other:?}"),
5254                    }
5255                    {
5256                        let mut st = conn.state.lock().await;
5257                        let pr = st.pending_requests.remove("id").unwrap();
5258                        pr.completion.send(Ok(serde_json::json!("ok"))).unwrap();
5259                    }
5260                    let resp = fut.await.unwrap().unwrap();
5261                    assert_eq!(resp, serde_json::json!("ok"));
5262                });
5263            }
5264
5265            #[test]
5266            fn async_send_default_connection() {
5267                TOKIO_SHARED_RT.block_on(async {
5268                    let conn = WebsocketConnection::new("c1");
5269                    let (tx, mut rx) = unbounded_channel::<Message>();
5270                    {
5271                        let mut st = conn.state.lock().await;
5272                        st.ws_write_tx = Some(tx);
5273                    }
5274                    let common = WebsocketCommon::new(
5275                        vec![conn.clone()],
5276                        WebsocketMode::Single,
5277                        0,
5278                        None,
5279                        None,
5280                    );
5281                    let fut = common
5282                        .send(
5283                            "msg".into(),
5284                            Some("id".into()),
5285                            true,
5286                            Duration::from_secs(5),
5287                            None,
5288                        )
5289                        .await
5290                        .unwrap()
5291                        .unwrap();
5292                    match rx.try_recv() {
5293                        Ok(Message::Text(t)) => assert_eq!(t, "msg"),
5294                        _ => panic!("no text"),
5295                    }
5296                    {
5297                        let mut st = conn.state.lock().await;
5298                        let pr = st.pending_requests.remove("id").unwrap();
5299                        pr.completion.send(Ok(serde_json::json!(123))).unwrap();
5300                    }
5301                    let resp = fut.await.unwrap().unwrap();
5302                    assert_eq!(resp, serde_json::json!(123));
5303                });
5304            }
5305
5306            #[test]
5307            fn async_send_error_if_no_id() {
5308                TOKIO_SHARED_RT.block_on(async {
5309                    let conn = WebsocketConnection::new("c§");
5310                    let (tx, _rx) = unbounded_channel::<Message>();
5311                    {
5312                        let mut st = conn.state.lock().await;
5313                        st.ws_write_tx = Some(tx);
5314                    }
5315                    let common = WebsocketCommon::new(
5316                        vec![conn.clone()],
5317                        WebsocketMode::Single,
5318                        0,
5319                        None,
5320                        None,
5321                    );
5322                    let err = common
5323                        .send(
5324                            "msg".into(),
5325                            None,
5326                            true,
5327                            Duration::from_secs(1),
5328                            Some(conn.clone()),
5329                        )
5330                        .await
5331                        .unwrap_err();
5332                    assert!(matches!(err, WebsocketError::NotConnected));
5333                });
5334            }
5335
5336            #[test]
5337            fn timeout_rejects_async() {
5338                TOKIO_SHARED_RT.block_on(async {
5339                    pause();
5340                    let conn = WebsocketConnection::new("c1");
5341                    let (tx, _rx) = unbounded_channel::<Message>();
5342                    {
5343                        let mut st = conn.state.lock().await;
5344                        st.ws_write_tx = Some(tx);
5345                    }
5346                    let common = WebsocketCommon::new(
5347                        vec![conn.clone()],
5348                        WebsocketMode::Single,
5349                        0,
5350                        None,
5351                        None,
5352                    );
5353                    let fut = common
5354                        .send(
5355                            "msg".into(),
5356                            Some("id".into()),
5357                            true,
5358                            Duration::from_secs(1),
5359                            Some(conn.clone()),
5360                        )
5361                        .await
5362                        .unwrap()
5363                        .unwrap();
5364                    advance(Duration::from_secs(1)).await;
5365                    let res = fut.await.unwrap();
5366                    assert!(res.is_err(), "expected timeout error");
5367                    assert!(!conn.state.lock().await.pending_requests.contains_key("id"));
5368                });
5369            }
5370
5371            #[test]
5372            fn async_send_errors_if_no_connection_ready() {
5373                TOKIO_SHARED_RT.block_on(async {
5374                    let conn = WebsocketConnection::new("c1");
5375                    let common = WebsocketCommon::new(
5376                        vec![conn.clone()],
5377                        WebsocketMode::Single,
5378                        0,
5379                        None,
5380                        None,
5381                    );
5382                    let err = common
5383                        .send(
5384                            "msg".into(),
5385                            Some("id".into()),
5386                            true,
5387                            Duration::from_secs(1),
5388                            None,
5389                        )
5390                        .await
5391                        .unwrap_err();
5392                    assert!(matches!(err, WebsocketError::NotConnected));
5393                });
5394            }
5395        }
5396    }
5397
5398    mod websocket_api {
5399        use super::*;
5400
mod initialisation {
    use super::*;

    /// `WebsocketApi::new` keeps the supplied pool and mode in its shared `common`
    /// state, and the `is_connecting` flag starts out cleared.
    #[test]
    fn new_initializes_common() {
        TOKIO_SHARED_RT.block_on(async {
            let connection = WebsocketConnection::new("id");
            let connections = vec![connection.clone()];

            let signature_gen = SignatureGenerator::new(
                Some("api_secret".to_string()),
                None::<PrivateKey>,
                None::<String>,
            );

            let api = WebsocketApi::new(
                ConfigurationWebsocketApi {
                    api_key: Some("api_key".to_string()),
                    api_secret: Some("api_secret".to_string()),
                    private_key: None,
                    private_key_passphrase: None,
                    ws_url: Some("wss://example".to_string()),
                    mode: WebsocketMode::Single,
                    reconnect_delay: 1000,
                    signature_gen,
                    timeout: 500,
                    time_unit: None,
                    auto_session_relogon: false,
                    agent: None,
                    user_agent: build_user_agent("product"),
                },
                connections.clone(),
            );

            assert_eq!(api.common.connection_pool.len(), 1);
            assert_eq!(api.common.mode, WebsocketMode::Single);

            // Copy the flag out so the mutex guard is released immediately.
            let is_connecting = *api.is_connecting.lock().await;
            assert!(!is_connecting);
        });
    }
}
5442
5443        mod connect {
5444            use super::*;
5445
5446            #[test]
5447            fn connect_when_not_connected_establishes() {
5448                TOKIO_SHARED_RT.block_on(async {
5449                    let conn = WebsocketConnection::new("id");
5450                    {
5451                        let mut st = conn.state.lock().await;
5452                        st.ws_write_tx = None;
5453                    }
5454                    let sig = SignatureGenerator::new(
5455                        Some("api_secret".into()),
5456                        None::<PrivateKey>,
5457                        None::<String>,
5458                    );
5459                    let cfg = ConfigurationWebsocketApi {
5460                        api_key: Some("api_key".into()),
5461                        api_secret: Some("api_secret".to_string()),
5462                        private_key: None,
5463                        private_key_passphrase: None,
5464                        ws_url: Some("ws://doesnotexist:1".to_string()),
5465                        mode: WebsocketMode::Single,
5466                        reconnect_delay: 0,
5467                        signature_gen: sig,
5468                        timeout: 10,
5469                        time_unit: None,
5470                        auto_session_relogon: false,
5471                        agent: None,
5472                        user_agent: build_user_agent("product"),
5473                    };
5474                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5475                    let res = api.clone().connect().await;
5476                    assert!(!matches!(res, Err(WebsocketError::Timeout)));
5477                });
5478            }
5479
5480            #[test]
5481            fn already_connected_returns_ok() {
5482                TOKIO_SHARED_RT.block_on(async {
5483                    let conn = WebsocketConnection::new("id2");
5484                    let (tx, _) = unbounded_channel();
5485                    {
5486                        let mut st = conn.state.lock().await;
5487                        st.ws_write_tx = Some(tx);
5488                    }
5489                    let sig = SignatureGenerator::new(
5490                        Some("api_secret".to_string()),
5491                        None::<PrivateKey>,
5492                        None::<String>,
5493                    );
5494                    let cfg = ConfigurationWebsocketApi {
5495                        api_key: Some("api_key".to_string()),
5496                        api_secret: Some("api_secret".to_string()),
5497                        private_key: None,
5498                        private_key_passphrase: None,
5499                        ws_url: Some("ws://example.com".to_string()),
5500                        mode: WebsocketMode::Single,
5501                        reconnect_delay: 0,
5502                        signature_gen: sig,
5503                        timeout: 10,
5504                        time_unit: None,
5505                        auto_session_relogon: false,
5506                        agent: None,
5507                        user_agent: build_user_agent("product"),
5508                    };
5509                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5510                    let res = api.connect().await;
5511                    assert!(res.is_ok());
5512                });
5513            }
5514
5515            #[test]
5516            fn not_connected_returns_error() {
5517                TOKIO_SHARED_RT.block_on(async {
5518                    let conn = WebsocketConnection::new("id1");
5519                    let sig = SignatureGenerator::new(
5520                        Some("api_secret".to_string()),
5521                        None::<PrivateKey>,
5522                        None::<String>,
5523                    );
5524                    let cfg = ConfigurationWebsocketApi {
5525                        api_key: Some("api_key".to_string()),
5526                        api_secret: Some("api_secret".to_string()),
5527                        private_key: None,
5528                        private_key_passphrase: None,
5529                        ws_url: Some("ws://127.0.0.1:9".to_string()),
5530                        mode: WebsocketMode::Single,
5531                        reconnect_delay: 0,
5532                        signature_gen: sig,
5533                        timeout: 10,
5534                        time_unit: None,
5535                        auto_session_relogon: false,
5536                        agent: None,
5537                        user_agent: build_user_agent("product"),
5538                    };
5539                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5540                    let res = api.connect().await;
5541                    assert!(res.is_err());
5542                });
5543            }
5544
5545            #[test]
5546            fn concurrent_calls_both_error_or_ok() {
5547                TOKIO_SHARED_RT.block_on(async {
5548                    let conn = WebsocketConnection::new("id3");
5549                    let sig = SignatureGenerator::new(
5550                        Some("api_secret".to_string()),
5551                        None::<PrivateKey>,
5552                        None::<String>,
5553                    );
5554                    let cfg = ConfigurationWebsocketApi {
5555                        api_key: Some("api_key".to_string()),
5556                        api_secret: Some("api_secret".to_string()),
5557                        private_key: None,
5558                        private_key_passphrase: None,
5559                        ws_url: Some("wss://invalid-domain".to_string()),
5560                        mode: WebsocketMode::Single,
5561                        reconnect_delay: 0,
5562                        signature_gen: sig,
5563                        timeout: 10,
5564                        time_unit: None,
5565                        auto_session_relogon: false,
5566                        agent: None,
5567                        user_agent: build_user_agent("product"),
5568                    };
5569                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5570                    let fut1 = tokio::spawn(api.clone().connect());
5571                    let fut2 = tokio::spawn(api.clone().connect());
5572                    let r1 = fut1.await.unwrap();
5573                    let r2 = fut2.await.unwrap();
5574
5575                    assert!(r1.is_err());
5576                    assert!(r2.is_err() || r2.is_ok());
5577                });
5578            }
5579
5580            #[test]
5581            fn pool_failure_is_propagated() {
5582                TOKIO_SHARED_RT.block_on(async {
5583                    let conn = WebsocketConnection::new("w");
5584                    let sig = SignatureGenerator::new(
5585                        Some("api_secret".to_string()),
5586                        None::<PrivateKey>,
5587                        None::<String>,
5588                    );
5589                    let cfg = ConfigurationWebsocketApi {
5590                        api_key: Some("api_key".into()),
5591                        api_secret: Some("api_secret".to_string()),
5592                        private_key: None,
5593                        private_key_passphrase: None,
5594                        ws_url: Some("ws://doesnotexist:1".to_string()),
5595                        mode: WebsocketMode::Single,
5596                        reconnect_delay: 0,
5597                        signature_gen: sig,
5598                        timeout: 10,
5599                        time_unit: None,
5600                        auto_session_relogon: false,
5601                        agent: None,
5602                        user_agent: build_user_agent("product"),
5603                    };
5604                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
5605                    let res = api.clone().connect().await;
5606                    match res {
5607                        Err(WebsocketError::Handshake(_) | WebsocketError::Timeout) => {}
5608                        _ => panic!("expected handshake or timeout error"),
5609                    }
5610                });
5611            }
5612        }
5613
5614        mod send_message {
5615            use super::*;
5616
/// An unsigned request without an API key is sent verbatim: caller params only,
/// no `apiKey`, `timestamp` or `signature`.
#[test]
fn unsigned_message() {
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, None, None);
        let conn = &api.common.connection_pool[0];
        let (tx, mut rx) = unbounded_channel::<Message>();
        conn.state.lock().await.ws_write_tx = Some(tx);

        let fut = tokio::spawn({
            let api = api.clone();
            async move {
                let mut params = BTreeMap::new();
                params.insert("foo".into(), Value::String("bar".into()));
                let send_res = api
                    .send_message::<Value>(
                        "method",
                        params,
                        WebsocketMessageSendOptions {
                            with_api_key: false,
                            is_signed: false,
                            ..Default::default()
                        },
                    )
                    .await
                    .unwrap();

                match send_res {
                    SendWebsocketMessageResult::Single(resp) => resp,
                    SendWebsocketMessageResult::Multiple(_) => {
                        panic!("expected single response")
                    }
                }
            }
        });

        let Message::Text(txt) = rx.recv().await.unwrap() else {
            panic!("expected a text frame on the write channel")
        };
        let req: Value = serde_json::from_str(&txt).unwrap();
        assert_eq!(req["method"], "method");
        assert_eq!(req["params"]["foo"], "bar");
        assert!(req["params"].get("apiKey").is_none());
        assert!(req["params"].get("timestamp").is_none());
        assert!(req["params"].get("signature").is_none());

        let id = req["id"].as_str().unwrap().to_string();
        // Resolve the pending request inside a scope so the state lock is released
        // before awaiting the send task, which may itself need the lock.
        {
            let mut st = conn.state.lock().await;
            let pending = st.pending_requests.remove(&id).unwrap();
            let reply = json!({
                "id": id,
                "result": { "x": 42 },
                "rateLimits": [{ "limit": 7 }]
            });
            pending.completion.send(Ok(reply)).unwrap();
        }

        let resp = fut.await.unwrap();
        // NOTE(review): the reply carries one `rateLimits` entry, yet `rate_limits` is
        // expected to be empty. Presumably the partial entry is rejected when the
        // response is parsed; confirm against the response parser.
        let rate_limits = resp.rate_limits.unwrap_or_default();

        assert!(rate_limits.is_empty());
        assert_eq!(resp.raw, json!({"x": 42}));
    });
}
5682
/// `with_api_key` without `is_signed` adds `apiKey` but no signature. Once the reply
/// is delivered, no pending request remains.
#[test]
fn with_api_key_only() {
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, None, None);
        let conn = &api.common.connection_pool[0];
        let (tx, mut rx) = unbounded_channel::<Message>();
        conn.state.lock().await.ws_write_tx = Some(tx);

        let fut = tokio::spawn({
            let api = api.clone();
            async move {
                let params = BTreeMap::new();
                let send_res = api
                    .send_message::<Value>(
                        "method",
                        params,
                        WebsocketMessageSendOptions {
                            with_api_key: true,
                            is_signed: false,
                            ..Default::default()
                        },
                    )
                    .await
                    .unwrap();

                match send_res {
                    SendWebsocketMessageResult::Single(resp) => resp,
                    SendWebsocketMessageResult::Multiple(_) => {
                        panic!("expected single response")
                    }
                }
            }
        });

        let Message::Text(txt) = rx.recv().await.unwrap() else {
            panic!("expected a text frame on the write channel")
        };
        let req: Value = serde_json::from_str(&txt).unwrap();
        assert_eq!(req["params"]["apiKey"], "api_key");
        assert!(req["params"].get("signature").is_none());

        let id = req["id"].as_str().unwrap().to_string();
        // Release the state lock before awaiting the send task.
        {
            let mut st = conn.state.lock().await;
            let pending = st.pending_requests.remove(&id).unwrap();
            pending
                .completion
                .send(Ok(json!({
                    "id": id,
                    "result": {},
                    "rateLimits": []
                })))
                .unwrap();
        }

        let resp = fut.await.unwrap();

        assert_eq!(resp.raw, json!({}));
        assert!(conn.state.lock().await.pending_requests.is_empty());
    });
}
5744
/// A signed request carries `apiKey`, a numeric `timestamp` and a string
/// `signature`, alongside the caller's own params.
#[test]
fn signed_message_has_timestamp_and_signature() {
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, None, None);
        let conn = &api.common.connection_pool[0];
        let (tx, mut rx) = unbounded_channel::<Message>();
        conn.state.lock().await.ws_write_tx = Some(tx);

        let fut = tokio::spawn({
            let api = api.clone();
            async move {
                let mut params = BTreeMap::new();
                params.insert("foo".into(), Value::String("bar".into()));
                let send_res = api
                    .send_message::<Value>(
                        "method",
                        params,
                        WebsocketMessageSendOptions {
                            with_api_key: true,
                            is_signed: true,
                            ..Default::default()
                        },
                    )
                    .await
                    .unwrap();

                match send_res {
                    SendWebsocketMessageResult::Single(resp) => resp,
                    SendWebsocketMessageResult::Multiple(_) => {
                        panic!("expected single response")
                    }
                }
            }
        });

        let Message::Text(txt) = rx.recv().await.unwrap() else {
            panic!("expected a text frame on the write channel")
        };
        let req: Value = serde_json::from_str(&txt).unwrap();
        let p = &req["params"];
        assert_eq!(p["foo"], "bar");
        assert_eq!(p["apiKey"], "api_key");
        assert!(p["timestamp"].is_number());
        assert!(p["signature"].is_string());

        let id = req["id"].as_str().unwrap().to_string();
        // Release the state lock before awaiting the send task.
        {
            let mut st = conn.state.lock().await;
            let pending = st.pending_requests.remove(&id).unwrap();
            pending
                .completion
                .send(Ok(json!({
                    "id": id,
                    "result": { "ok": true },
                    "rateLimits": []
                })))
                .unwrap();
        }

        let resp = fut.await.unwrap();
        assert_eq!(resp.raw, json!({ "ok": true }));
    });
}
5808
/// In pool mode, a session-logon request is broadcast to every connection with the
/// same id. Afterwards each connection is marked logged on and keeps the logon
/// request.
#[test]
fn multi_session_logon() {
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, Some(WebsocketMode::Pool(2)), None);
        let conn0 = &api.common.connection_pool[0];
        let conn1 = &api.common.connection_pool[1];

        let (tx0, mut rx0) = unbounded_channel::<Message>();
        let (tx1, mut rx1) = unbounded_channel::<Message>();
        conn0.state.lock().await.ws_write_tx = Some(tx0);
        conn1.state.lock().await.ws_write_tx = Some(tx1);

        let fut = tokio::spawn({
            let api = api.clone();
            async move {
                let params = BTreeMap::new();
                let send_res = api
                    .send_message::<Value>(
                        "method",
                        params,
                        WebsocketMessageSendOptions {
                            is_session_logon: Some(true),
                            ..Default::default()
                        },
                    )
                    .await
                    .unwrap();

                match send_res {
                    SendWebsocketMessageResult::Multiple(v) => v,
                    SendWebsocketMessageResult::Single(_) => {
                        panic!("expected multiple responses")
                    }
                }
            }
        });

        let Message::Text(txt0) = rx0.recv().await.unwrap() else {
            panic!("expected a text frame on connection 0")
        };
        let Message::Text(txt1) = rx1.recv().await.unwrap() else {
            panic!("expected a text frame on connection 1")
        };
        let req0: Value = serde_json::from_str(&txt0).unwrap();
        let req1: Value = serde_json::from_str(&txt1).unwrap();
        assert_eq!(req0["method"], "method");
        assert_eq!(req1["method"], "method");
        let id = req0["id"].as_str().unwrap().to_string();
        assert_eq!(req1["id"].as_str().unwrap(), &id);

        // Complete the request on each connection; each lock is scoped so it is released
        // before the send task is awaited.
        for conn in [conn0, conn1] {
            let mut st = conn.state.lock().await;
            let pending = st.pending_requests.remove(&id).unwrap();
            pending
                .completion
                .send(Ok(json!({
                    "id": id,
                    "result": { "ok": true },
                    "rateLimits": []
                })))
                .unwrap();
        }

        let results = fut.await.unwrap();
        assert_eq!(results.len(), 2);

        for conn in &api.common.connection_pool {
            let st = conn.state.lock().await;
            assert!(st.is_session_logged_on, "should be logged on");
            assert!(st.session_logon_req.is_some(), "logon req should be retained");
        }
    });
}
5916
/// In pool mode, a session-logout request is broadcast to every connection. Each
/// connection then ends up logged out with its stored logon request cleared.
#[test]
fn multi_session_logout() {
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, Some(WebsocketMode::Pool(2)), None);

        // Put every connection in the "logged on" state and capture its outgoing frames.
        let mut rxs = Vec::with_capacity(api.common.connection_pool.len());
        for conn in &api.common.connection_pool {
            let (tx, rx) = unbounded_channel::<Message>();
            let mut st = conn.state.lock().await;
            st.ws_write_tx = Some(tx);
            st.is_session_logged_on = true;
            st.session_logon_req = Some(WebsocketSessionLogonReq {
                method: "method".into(),
                payload: BTreeMap::new(),
                options: WebsocketMessageSendOptions::default(),
            });
            rxs.push(rx);
        }

        let fut = tokio::spawn({
            let api = api.clone();
            async move {
                let send_res = api
                    .send_message::<Value>(
                        "method",
                        BTreeMap::new(),
                        WebsocketMessageSendOptions {
                            is_signed: false,
                            with_api_key: false,
                            is_session_logout: Some(true),
                            ..Default::default()
                        },
                    )
                    .await
                    .unwrap();

                match send_res {
                    SendWebsocketMessageResult::Multiple(v) => v,
                    SendWebsocketMessageResult::Single(_) => panic!("expected multi"),
                }
            }
        });

        let mut ids = Vec::with_capacity(rxs.len());
        for mut rx in rxs {
            let Message::Text(txt) = rx.recv().await.unwrap() else {
                panic!("expected a text frame on the write channel")
            };
            let req: Value = serde_json::from_str(&txt).unwrap();
            assert_eq!(req["method"], "method");
            ids.push(req["id"].as_str().unwrap().to_string());
        }

        assert_eq!(ids[0], ids[1]);

        let id = &ids[0];
        for conn in &api.common.connection_pool {
            let mut st = conn.state.lock().await;
            let pending = st.pending_requests.remove(id).unwrap();
            pending
                .completion
                .send(Ok(json!({
                    "id": id,
                    "result": {},
                    "rateLimits": []
                })))
                .unwrap();
        }

        let results = fut.await.unwrap();
        assert_eq!(results.len(), 2);

        for conn in &api.common.connection_pool {
            let st = conn.state.lock().await;
            assert!(!st.is_session_logged_on, "should be logged out");
            assert!(st.session_logon_req.is_none(), "req cleared");
        }
    });
}
6004
/// On a session-logged-on connection with auto re-logon, a signed request is sent
/// with a `timestamp` but without a `signature`.
///
/// The third `create_websocket_api` argument is `None`, which leaves auto session
/// re-logon at the helper's default. Presumably that default is enabled, since the
/// test relies on it; confirm in `create_websocket_api`.
#[test]
fn skip_signature_when_logged_on_and_auto_relogon() {
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, Some(WebsocketMode::Single), None);
        let conn = &api.common.connection_pool[0];
        let (tx, mut rx) = unbounded_channel::<Message>();
        {
            let mut st = conn.state.lock().await;
            st.ws_write_tx = Some(tx);
            st.is_session_logged_on = true;
        }

        let fut = tokio::spawn({
            let api = api.clone();
            async move {
                let send_res = api
                    .send_message::<Value>(
                        "method",
                        BTreeMap::new(),
                        WebsocketMessageSendOptions {
                            is_signed: true,
                            ..Default::default()
                        },
                    )
                    .await
                    .unwrap();

                match send_res {
                    SendWebsocketMessageResult::Single(resp) => resp,
                    SendWebsocketMessageResult::Multiple(_) => {
                        panic!("expected single")
                    }
                }
            }
        });

        let Message::Text(txt) = rx.recv().await.unwrap() else {
            panic!("expected a text frame on the write channel")
        };
        let req: Value = serde_json::from_str(&txt).unwrap();
        let p = &req["params"];
        assert!(p.get("timestamp").is_some());
        assert!(p.get("signature").is_none());

        let id = req["id"].as_str().unwrap();
        // Release the state lock before awaiting the send task.
        {
            let mut st = conn.state.lock().await;
            let pending = st.pending_requests.remove(id).unwrap();
            pending
                .completion
                .send(Ok(json!({
                    "id": id,
                    "result": {},
                    "rateLimits": []
                })))
                .unwrap();
        }

        let resp = fut.await.unwrap();
        assert_eq!(resp.raw, json!({}));
    });
}
6072
/// With auto session re-logon disabled (`Some(false)`), a signed request on a
/// logged-on connection still carries both `timestamp` and `signature`.
#[test]
fn include_signature_when_logged_on_and_no_auto_relogon() {
    TOKIO_SHARED_RT.block_on(async {
        let api = create_websocket_api(None, Some(WebsocketMode::Single), Some(false));
        let conn = &api.common.connection_pool[0];
        let (tx, mut rx) = unbounded_channel::<Message>();
        {
            let mut st = conn.state.lock().await;
            st.ws_write_tx = Some(tx);
            st.is_session_logged_on = true;
        }

        let fut = tokio::spawn({
            let api = api.clone();
            async move {
                let send_res = api
                    .send_message::<Value>(
                        "method",
                        BTreeMap::new(),
                        WebsocketMessageSendOptions {
                            is_signed: true,
                            ..Default::default()
                        },
                    )
                    .await
                    .unwrap();

                match send_res {
                    SendWebsocketMessageResult::Single(resp) => resp,
                    SendWebsocketMessageResult::Multiple(_) => {
                        panic!("expected single")
                    }
                }
            }
        });

        let Message::Text(txt) = rx.recv().await.unwrap() else {
            panic!("expected a text frame on the write channel")
        };
        let req: Value = serde_json::from_str(&txt).unwrap();
        let p = &req["params"];
        assert!(p.get("timestamp").is_some());
        assert!(p.get("signature").is_some());

        let id = req["id"].as_str().unwrap();
        // Release the state lock before awaiting the send task.
        {
            let mut st = conn.state.lock().await;
            let pending = st.pending_requests.remove(id).unwrap();
            pending
                .completion
                .send(Ok(json!({
                    "id": id,
                    "result": {},
                    "rateLimits": []
                })))
                .unwrap();
        }

        let resp = fut.await.unwrap();
        assert_eq!(resp.raw, json!({}));
    });
}
6140
6141            #[test]
6142            fn error_if_not_connected() {
6143                TOKIO_SHARED_RT.block_on(async {
6144                    let api = create_websocket_api(None, None, None);
6145                    let conn = &api.common.connection_pool[0];
6146                    {
6147                        let mut st = conn.state.lock().await;
6148                        st.ws_write_tx = None;
6149                    }
6150                    let params = BTreeMap::new();
6151                    let err = api
6152                        .send_message::<Value>(
6153                            "method",
6154                            params,
6155                            WebsocketMessageSendOptions {
6156                                with_api_key: false,
6157                                is_signed: false,
6158                                ..Default::default()
6159                            },
6160                        )
6161                        .await
6162                        .unwrap_err();
6163                    matches!(err, WebsocketError::NotConnected);
6164                });
6165            }
6166        }
6167
6168        mod prepare_url {
6169            use super::*;
6170
6171            #[test]
6172            fn no_time_unit() {
6173                TOKIO_SHARED_RT.block_on(async {
6174                    let api = create_websocket_api(None, None, None);
6175                    let url = "wss://example.com/ws".to_string();
6176                    assert_eq!(api.prepare_url(&url), url);
6177                });
6178            }
6179
6180            #[test]
6181            fn appends_time_unit() {
6182                TOKIO_SHARED_RT.block_on(async {
6183                    let api = create_websocket_api(Some(TimeUnit::Millisecond), None, None);
6184                    let base = "wss://example.com/ws".to_string();
6185                    let got = api.prepare_url(&base);
6186                    assert_eq!(got, format!("{base}?timeUnit=millisecond"));
6187                });
6188            }
6189
6190            #[test]
6191            fn handles_existing_query() {
6192                TOKIO_SHARED_RT.block_on(async {
6193                    let api = create_websocket_api(Some(TimeUnit::Microsecond), None, None);
6194                    let base = "wss://example.com/ws?foo=bar".to_string();
6195                    let got = api.prepare_url(&base);
6196                    assert_eq!(got, format!("{base}&timeUnit=microsecond"));
6197                });
6198            }
6199        }
6200
6201        mod on_open {
6202            use super::*;
6203
6204            fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
6205                let sig_gen = SignatureGenerator::new(
6206                    Some("api_secret".to_string()),
6207                    None::<_>,
6208                    None::<String>,
6209                );
6210                let config = ConfigurationWebsocketApi {
6211                    api_key: Some("api_key".to_string()),
6212                    api_secret: Some("api_secret".to_string()),
6213                    private_key: None,
6214                    private_key_passphrase: None,
6215                    ws_url: Some("wss://example".to_string()),
6216                    mode: WebsocketMode::Single,
6217                    reconnect_delay: 0,
6218                    signature_gen: sig_gen,
6219                    timeout: 1000,
6220                    time_unit: None,
6221                    auto_session_relogon: true,
6222                    agent: None,
6223                    user_agent: build_user_agent("product"),
6224                };
6225                let conn = WebsocketConnection::new("test-conn");
6226                let api = WebsocketApi::new(config, vec![conn.clone()]);
6227                (api, conn)
6228            }
6229
6230            #[test]
6231            fn session_relogon_on_open() {
6232                TOKIO_SHARED_RT.block_on(async {
6233                    let (api, conn) = create_websocket_api_and_conn();
6234
6235                    let req = WebsocketSessionLogonReq {
6236                        method: "method".into(),
6237                        payload: {
6238                            let mut m = BTreeMap::new();
6239                            m.insert("foo".into(), Value::String("bar".into()));
6240                            m
6241                        },
6242                        options: WebsocketMessageSendOptions {
6243                            with_api_key: true,
6244                            is_signed: true,
6245                            is_session_logon: Some(true),
6246                            ..Default::default()
6247                        },
6248                    };
6249
6250                    let (tx, mut rx) = unbounded_channel::<Message>();
6251                    {
6252                        let mut st = conn.state.lock().await;
6253                        st.session_logon_req = Some(req.clone());
6254                        st.is_session_logged_on = false;
6255                        st.ws_write_tx = Some(tx);
6256                    }
6257
6258                    api.on_open("wss://example".to_string(), conn.clone()).await;
6259
6260                    let Message::Text(raw) = rx.recv().await.unwrap() else {
6261                        panic!("expected a Text message");
6262                    };
6263                    let msg: Value = serde_json::from_str(&raw).unwrap();
6264                    assert_eq!(msg["method"], "method");
6265                    assert_eq!(msg["params"]["foo"], "bar");
6266
6267                    let id = msg["id"].as_str().unwrap().to_string();
6268                    {
6269                        let mut st = conn.state.lock().await;
6270                        let pending = st.pending_requests.remove(&id).expect("pending request");
6271                        pending
6272                            .completion
6273                            .send(Ok(json!({
6274                                "id": id,
6275                                "result": {},
6276                                "rateLimits": []
6277                            })))
6278                            .unwrap();
6279                    }
6280
6281                    sleep(Duration::from_millis(10)).await;
6282
6283                    let st = conn.state.lock().await;
6284                    assert!(st.is_session_logged_on, "should now be logged on");
6285                });
6286            }
6287
6288            #[test]
6289            fn no_relogon_if_already_logged_on() {
6290                TOKIO_SHARED_RT.block_on(async {
6291                    let (api, conn) = create_websocket_api_and_conn();
6292
6293                    let req = WebsocketSessionLogonReq {
6294                        method: "method".into(),
6295                        payload: BTreeMap::new(),
6296                        options: WebsocketMessageSendOptions {
6297                            is_session_logon: Some(true),
6298                            ..Default::default()
6299                        },
6300                    };
6301
6302                    let (tx, mut rx) = unbounded_channel::<Message>();
6303                    {
6304                        let mut st = conn.state.lock().await;
6305                        st.session_logon_req = Some(req);
6306                        st.is_session_logged_on = true;
6307                        st.ws_write_tx = Some(tx);
6308                    }
6309
6310                    api.on_open("wss://example".to_string(), conn.clone()).await;
6311
6312                    assert!(rx.try_recv().is_err(), "no re‐logon when already on");
6313
6314                    let st = conn.state.lock().await;
6315                    assert!(st.is_session_logged_on);
6316                });
6317            }
6318
6319            #[test]
6320            fn session_relogon_fails_gracefully() {
6321                TOKIO_SHARED_RT.block_on(async {
6322                    let (api, conn) = create_websocket_api_and_conn();
6323
6324                    let req = WebsocketSessionLogonReq {
6325                        method: "method".into(),
6326                        payload: {
6327                            let mut m = BTreeMap::new();
6328                            m.insert("x".into(), Value::Number(1.into()));
6329                            m
6330                        },
6331                        options: WebsocketMessageSendOptions {
6332                            is_session_logon: Some(true),
6333                            ..Default::default()
6334                        },
6335                    };
6336                    {
6337                        let mut st = conn.state.lock().await;
6338                        st.session_logon_req = Some(req);
6339                        st.is_session_logged_on = false;
6340                        st.ws_write_tx = None;
6341                    }
6342
6343                    api.on_open("wss://example".into(), conn.clone()).await;
6344
6345                    let st = conn.state.lock().await;
6346                    assert!(
6347                        !st.is_session_logged_on,
6348                        "should remain logged‐off on failure"
6349                    );
6350                });
6351            }
6352
6353            #[test]
6354            fn session_relogon_noop_when_no_req() {
6355                TOKIO_SHARED_RT.block_on(async {
6356                    let (api, conn) = create_websocket_api_and_conn();
6357
6358                    {
6359                        let mut st = conn.state.lock().await;
6360                        st.session_logon_req = None;
6361                        st.is_session_logged_on = false;
6362                        st.ws_write_tx = Some(unbounded_channel::<Message>().0);
6363                    }
6364
6365                    api.on_open("wss://example".into(), conn.clone()).await;
6366
6367                    let st = conn.state.lock().await;
6368                    assert!(!st.is_session_logged_on, "still logged‐off");
6369                });
6370            }
6371        }
6372
6373        mod on_message {
6374            use super::*;
6375
6376            fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
6377                let sig_gen = SignatureGenerator::new(
6378                    Some("api_secret".to_string()),
6379                    None::<_>,
6380                    None::<String>,
6381                );
6382                let config = ConfigurationWebsocketApi {
6383                    api_key: Some("api_key".to_string()),
6384                    api_secret: Some("api_secret".to_string()),
6385                    private_key: None,
6386                    private_key_passphrase: None,
6387                    ws_url: Some("wss://example".to_string()),
6388                    mode: WebsocketMode::Single,
6389                    reconnect_delay: 0,
6390                    signature_gen: sig_gen,
6391                    timeout: 1000,
6392                    time_unit: None,
6393                    auto_session_relogon: false,
6394                    agent: None,
6395                    user_agent: build_user_agent("product"),
6396                };
6397                let conn = WebsocketConnection::new("test");
6398                let api = WebsocketApi::new(config, vec![conn.clone()]);
6399                (api, conn)
6400            }
6401
6402            #[test]
6403            fn resolves_pending_and_removes_request() {
6404                TOKIO_SHARED_RT.block_on(async {
6405                    let (api, conn) = create_websocket_api_and_conn();
6406                    let (tx, rx) = oneshot::channel();
6407                    {
6408                        let mut st = conn.state.lock().await;
6409                        st.pending_requests
6410                            .insert("id1".to_string(), PendingRequest { completion: tx });
6411                    }
6412                    let msg = json!({"id":"id1","status":200,"foo":"bar"});
6413                    api.on_message(msg.to_string(), conn.clone()).await;
6414                    let got = rx.await.unwrap().unwrap();
6415                    assert_eq!(got, msg);
6416                    let st = conn.state.lock().await;
6417                    assert!(!st.pending_requests.contains_key("id1"));
6418                });
6419            }
6420
6421            #[test]
6422            fn uses_result_when_present() {
6423                TOKIO_SHARED_RT.block_on(async {
6424                    let (api, conn) = create_websocket_api_and_conn();
6425                    let (tx, rx) = oneshot::channel();
6426                    {
6427                        let mut st = conn.state.lock().await;
6428                        st.pending_requests
6429                            .insert("id1".to_string(), PendingRequest { completion: tx });
6430                    }
6431                    let msg = json!({
6432                        "id": "id1",
6433                        "status": 200,
6434                        "response": [1,2],
6435                        "result": {"a":1}
6436                    });
6437                    api.on_message(msg.to_string(), conn.clone()).await;
6438                    let got = rx.await.unwrap().unwrap();
6439                    assert_eq!(got.get("result").unwrap(), &json!({"a":1}));
6440                });
6441            }
6442
6443            #[test]
6444            fn uses_response_when_no_result() {
6445                TOKIO_SHARED_RT.block_on(async {
6446                    let (api, conn) = create_websocket_api_and_conn();
6447                    let (tx, rx) = oneshot::channel();
6448                    {
6449                        let mut st = conn.state.lock().await;
6450                        st.pending_requests
6451                            .insert("id1".to_string(), PendingRequest { completion: tx });
6452                    }
6453                    let msg = json!({
6454                        "id": "id1",
6455                        "status": 200,
6456                        "response": ["ok"]
6457                    });
6458                    api.on_message(msg.to_string(), conn.clone()).await;
6459                    let got = rx.await.unwrap().unwrap();
6460                    assert_eq!(got.get("response").unwrap(), &json!(["ok"]));
6461                });
6462            }
6463
6464            #[test]
6465            fn errors_for_status_ge_400() {
6466                TOKIO_SHARED_RT.block_on(async {
6467                    let (api, conn) = create_websocket_api_and_conn();
6468                    let (tx, rx) = oneshot::channel();
6469                    {
6470                        let mut st = conn.state.lock().await;
6471                        st.pending_requests
6472                            .insert("bad".to_string(), PendingRequest { completion: tx });
6473                    }
6474                    let err_obj = json!({"code":123,"msg":"oops"});
6475                    let msg = json!({"id":"bad","status":500,"error":err_obj});
6476                    api.on_message(msg.to_string(), conn.clone()).await;
6477                    match rx.await.unwrap() {
6478                        Err(WebsocketError::ResponseError { code, message }) => {
6479                            assert_eq!(code, 123);
6480                            assert_eq!(message, "oops");
6481                        }
6482                        other => panic!("expected ResponseError, got {other:?}"),
6483                    }
6484                    let st = conn.state.lock().await;
6485                    assert!(!st.pending_requests.contains_key("bad"));
6486                });
6487            }
6488
6489            #[test]
6490            fn ignores_unknown_id() {
6491                TOKIO_SHARED_RT.block_on(async {
6492                    let (api, conn) = create_websocket_api_and_conn();
6493                    let msg = json!({"id":"nope","status":200});
6494                    api.on_message(msg.to_string(), conn.clone()).await;
6495                    let st = conn.state.lock().await;
6496                    assert!(st.pending_requests.is_empty());
6497                });
6498            }
6499
6500            #[test]
6501            fn parse_error_ignored() {
6502                TOKIO_SHARED_RT.block_on(async {
6503                    let (api, conn) = create_websocket_api_and_conn();
6504                    api.on_message("not json".to_string(), conn.clone()).await;
6505                    let st = conn.state.lock().await;
6506                    assert!(st.pending_requests.is_empty());
6507                });
6508            }
6509
6510            #[test]
6511            fn error_status_sends_error() {
6512                TOKIO_SHARED_RT.block_on(async {
6513                    let (api, conn) = create_websocket_api_and_conn();
6514                    let (tx, rx) = oneshot::channel();
6515                    {
6516                        let mut st = conn.state.lock().await;
6517                        st.pending_requests
6518                            .insert("err".to_string(), PendingRequest { completion: tx });
6519                    }
6520                    let msg = json!({
6521                        "id": "err",
6522                        "status": 500,
6523                        "error": { "code": 42, "msg": "Bad!" }
6524                    });
6525                    api.on_message(msg.to_string(), conn.clone()).await;
6526                    match rx.await.unwrap() {
6527                        Err(WebsocketError::ResponseError { code, message }) => {
6528                            assert_eq!(code, 42);
6529                            assert_eq!(message, "Bad!");
6530                        }
6531                        other => panic!("expected ResponseError, got {other:?}"),
6532                    }
6533                });
6534            }
6535
6536            #[test]
6537            fn unknown_id_logs_warning_and_leaves_pending() {
6538                TOKIO_SHARED_RT.block_on(async {
6539                    let (api, conn) = create_websocket_api_and_conn();
6540                    {
6541                        let mut st = conn.state.lock().await;
6542                        st.pending_requests.insert(
6543                            "keep".to_string(),
6544                            PendingRequest {
6545                                completion: oneshot::channel().0,
6546                            },
6547                        );
6548                    }
6549                    api.on_message(
6550                        json!({ "id": "foo", "status": 200, "result": 1 }).to_string(),
6551                        conn.clone(),
6552                    )
6553                    .await;
6554                    let st = conn.state.lock().await;
6555                    assert!(st.pending_requests.contains_key("keep"));
6556                });
6557            }
6558        }
6559    }
6560
6561    mod websocket_streams {
6562        use super::*;
6563
6564        mod initialisation {
6565            use super::*;
6566
6567            #[test]
6568            fn new_initializes_fields() {
6569                TOKIO_SHARED_RT.block_on(async {
6570                    let config = ConfigurationWebsocketStreams {
6571                        ws_url: Some("wss://example".to_string()),
6572                        mode: WebsocketMode::Pool(2),
6573                        reconnect_delay: 500,
6574                        time_unit: None,
6575                        agent: None,
6576                        user_agent: build_user_agent("product"),
6577                    };
6578                    let conn1 = WebsocketConnection::new("c1");
6579                    let conn2 = WebsocketConnection::new("c2");
6580                    let api = WebsocketStreams::new(
6581                        config.clone(),
6582                        vec![conn1.clone(), conn2.clone()],
6583                        vec![],
6584                    );
6585
6586                    assert_eq!(api.common.connection_pool.len(), 2);
6587                    assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
6588                    assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
6589                    assert_eq!(api.configuration.ws_url, Some("wss://example".to_string()));
6590                    let flag = api.is_connecting.lock().await;
6591                    assert!(!*flag);
6592                });
6593            }
6594
6595            #[test]
6596            fn new_expands_pool_when_url_paths_present() {
6597                TOKIO_SHARED_RT.block_on(async {
6598                    let config = ConfigurationWebsocketStreams {
6599                        ws_url: Some("wss://example".to_string()),
6600                        mode: WebsocketMode::Pool(2),
6601                        reconnect_delay: 500,
6602                        time_unit: None,
6603                        agent: None,
6604                        user_agent: build_user_agent("product"),
6605                    };
6606
6607                    let conn1 = WebsocketConnection::new("c1");
6608                    let conn2 = WebsocketConnection::new("c2");
6609
6610                    let api = WebsocketStreams::new(
6611                        config,
6612                        vec![conn1.clone(), conn2.clone()],
6613                        vec!["path1".to_string(), "path2".to_string()],
6614                    );
6615
6616                    assert_eq!(api.common.connection_pool.len(), 4);
6617                    assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
6618                    assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
6619                });
6620            }
6621
6622            #[test]
6623            fn new_does_not_expand_pool_when_already_sized_for_url_paths() {
6624                TOKIO_SHARED_RT.block_on(async {
6625                    let config = ConfigurationWebsocketStreams {
6626                        ws_url: Some("wss://example".to_string()),
6627                        mode: WebsocketMode::Pool(2),
6628                        reconnect_delay: 500,
6629                        time_unit: None,
6630                        agent: None,
6631                        user_agent: build_user_agent("product"),
6632                    };
6633
6634                    let conns = vec![
6635                        WebsocketConnection::new("c1"),
6636                        WebsocketConnection::new("c2"),
6637                        WebsocketConnection::new("c3"),
6638                        WebsocketConnection::new("c4"),
6639                    ];
6640
6641                    let api = WebsocketStreams::new(
6642                        config,
6643                        conns.clone(),
6644                        vec!["path1".to_string(), "path2".to_string()],
6645                    );
6646
6647                    assert_eq!(api.common.connection_pool.len(), 4);
6648                    for (i, c) in conns.iter().enumerate() {
6649                        assert!(Arc::ptr_eq(&api.common.connection_pool[i], c));
6650                    }
6651                });
6652            }
6653        }
6654
6655        mod connect {
6656            use super::*;
6657
6658            #[test]
6659            fn establishes_successfully() {
6660                TOKIO_SHARED_RT.block_on(async {
6661                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
6662                    let port = listener.local_addr().unwrap().port();
6663
6664                    tokio::spawn(async move {
6665                        for _ in 0..2 {
6666                            if let Ok((stream, _)) = listener.accept().await {
6667                                let mut ws = accept_async(stream).await.unwrap();
6668                                ws.close(None).await.ok();
6669                            }
6670                        }
6671                    });
6672
6673                    let create_websocket_streams = |ws_url: &str| {
6674                        let c1 = WebsocketConnection::new("c1");
6675                        let c2 = WebsocketConnection::new("c2");
6676                        let config = ConfigurationWebsocketStreams {
6677                            ws_url: Some(ws_url.to_string()),
6678                            mode: WebsocketMode::Pool(2),
6679                            reconnect_delay: 500,
6680                            time_unit: None,
6681                            agent: None,
6682                            user_agent: build_user_agent("product"),
6683                        };
6684                        WebsocketStreams::new(config, vec![c1, c2], vec![])
6685                    };
6686
6687                    let url = format!("ws://127.0.0.1:{port}");
6688                    let ws = create_websocket_streams(&url);
6689
6690                    let res = ws.connect(vec!["stream1".into()]).await;
6691                    assert!(res.is_ok());
6692                });
6693            }
6694
6695            #[test]
6696            fn establishes_successfully_with_url_paths() {
6697                TOKIO_SHARED_RT.block_on(async {
6698                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
6699                    let addr = listener.local_addr().unwrap();
6700
6701                    tokio::spawn(async move {
6702                        for _ in 0..4 {
6703                            if let Ok((stream, _)) = listener.accept().await {
6704                                let mut ws = accept_async(stream).await.unwrap();
6705                                ws.close(None).await.ok();
6706                            }
6707                        }
6708                    });
6709
6710                    let config = ConfigurationWebsocketStreams {
6711                        ws_url: Some(format!("ws://{}", addr)),
6712                        mode: WebsocketMode::Pool(2),
6713                        reconnect_delay: 500,
6714                        time_unit: None,
6715                        agent: None,
6716                        user_agent: build_user_agent("product"),
6717                    };
6718
6719                    let ws = WebsocketStreams::new(
6720                        config,
6721                        vec![],
6722                        vec!["path1".to_string(), "path2".to_string()],
6723                    );
6724
6725                    let res = ws.clone().connect(vec!["stream1".into()]).await;
6726                    assert!(res.is_ok());
6727                    assert_eq!(ws.common.connection_pool.len(), 4);
6728                });
6729            }
6730
6731            #[test]
6732            fn connect_sets_url_path_on_connections_when_url_paths_present() {
6733                TOKIO_SHARED_RT.block_on(async {
6734                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
6735                    let addr = listener.local_addr().unwrap();
6736
6737                    tokio::spawn(async move {
6738                        for _ in 0..4 {
6739                            if let Ok((stream, _)) = listener.accept().await {
6740                                let mut ws = accept_async(stream).await.unwrap();
6741                                ws.close(None).await.ok();
6742                            }
6743                        }
6744                    });
6745
6746                    let config = ConfigurationWebsocketStreams {
6747                        ws_url: Some(format!("ws://{}", addr)),
6748                        mode: WebsocketMode::Pool(2),
6749                        reconnect_delay: 500,
6750                        time_unit: None,
6751                        agent: None,
6752                        user_agent: build_user_agent("product"),
6753                    };
6754
6755                    let ws = WebsocketStreams::new(
6756                        config,
6757                        vec![],
6758                        vec!["path1".to_string(), "path2".to_string()],
6759                    );
6760
6761                    ws.clone().connect(vec!["stream1".into()]).await.unwrap();
6762
6763                    let pool_size = ws.configuration.mode.pool_size();
6764
6765                    for (i, conn) in ws.common.connection_pool.iter().enumerate() {
6766                        let expected = if i < pool_size { "path1" } else { "path2" };
6767                        let st = conn.state.lock().await;
6768                        assert_eq!(st.url_path.as_deref(), Some(expected));
6769                    }
6770                });
6771            }
6772
6773            #[test]
6774            fn refused_returns_error() {
6775                TOKIO_SHARED_RT.block_on(async {
6776                    let ws = create_websocket_streams(Some("ws://127.0.0.1:9"), None, None);
6777                    let res = ws.connect(vec!["stream1".into()]).await;
6778                    assert!(res.is_err());
6779                });
6780            }
6781
6782            #[test]
6783            fn invalid_url_returns_error() {
6784                TOKIO_SHARED_RT.block_on(async {
6785                    let ws = create_websocket_streams(Some("not-a-url"), None, None);
6786                    let res = ws.connect(vec!["s".into()]).await;
6787                    assert!(res.is_err());
6788                });
6789            }
6790        }
6791
6792        mod disconnect {
6793            use super::*;
6794
6795            #[test]
6796            fn disconnect_clears_state_and_streams() {
6797                TOKIO_SHARED_RT.block_on(async {
6798                    let ws = create_websocket_streams(None, None, None);
6799                    let conn = &ws.common.connection_pool[0];
6800                    {
6801                        let mut state = conn.state.lock().await;
6802                        state.stream_callbacks.insert("s1".to_string(), Vec::new());
6803                        state.pending_subscriptions.push_back("s2".to_string());
6804                    }
6805                    {
6806                        let mut map = ws.connection_streams.lock().await;
6807                        map.insert("s3".to_string(), Arc::clone(conn));
6808                    }
6809
6810                    let res = ws.disconnect().await;
6811                    assert!(res.is_ok());
6812
6813                    let state = conn.state.lock().await;
6814                    assert!(state.stream_callbacks.is_empty());
6815                    assert!(state.pending_subscriptions.is_empty());
6816
6817                    let map = ws.connection_streams.lock().await;
6818                    assert!(map.is_empty());
6819                });
6820            }
6821        }
6822
6823        mod subscribe {
6824            use super::*;
6825
6826            #[test]
6827            fn empty_list_does_nothing() {
6828                TOKIO_SHARED_RT.block_on(async {
6829                    let ws = create_websocket_streams(None, None, None);
6830                    ws.clone().subscribe(Vec::new(), None, None).await;
6831                    let map = ws.connection_streams.lock().await;
6832                    assert!(map.is_empty());
6833                });
6834            }
6835
6836            #[test]
6837            fn queue_when_not_ready() {
6838                TOKIO_SHARED_RT.block_on(async {
6839                    let ws = create_websocket_streams(None, None, None);
6840                    let conn = ws.common.connection_pool[0].clone();
6841                    ws.clone().subscribe(vec!["s1".into()], None, None).await;
6842                    let state = conn.state.lock().await;
6843                    let pending: Vec<String> =
6844                        state.pending_subscriptions.iter().cloned().collect();
6845                    assert_eq!(pending, vec!["s1".to_string()]);
6846                });
6847            }
6848
6849            #[test]
6850            fn only_one_subscription_per_stream() {
6851                TOKIO_SHARED_RT.block_on(async {
6852                    let ws = create_websocket_streams(None, None, None);
6853                    let conn = ws.common.connection_pool[0].clone();
6854                    ws.clone().subscribe(vec!["s1".into()], None, None).await;
6855                    ws.clone().subscribe(vec!["s1".into()], None, None).await;
6856                    let state = conn.state.lock().await;
6857                    let pending: Vec<String> =
6858                        state.pending_subscriptions.iter().cloned().collect();
6859                    assert_eq!(pending, vec!["s1".to_string()]);
6860                });
6861            }
6862
6863            #[test]
6864            fn multiple_streams_assigned() {
6865                TOKIO_SHARED_RT.block_on(async {
6866                    let ws = create_websocket_streams(None, None, None);
6867                    ws.clone()
6868                        .subscribe(vec!["s1".into(), "s2".into()], None, None)
6869                        .await;
6870                    let map = ws.connection_streams.lock().await;
6871                    assert!(map.contains_key("s1"));
6872                    assert!(map.contains_key("s2"));
6873                });
6874            }
6875
6876            #[test]
6877            fn existing_stream_not_reassigned() {
6878                TOKIO_SHARED_RT.block_on(async {
6879                    let ws = create_websocket_streams(None, None, None);
6880                    ws.clone().subscribe(vec!["s1".into()], None, None).await;
6881                    let first_id = {
6882                        let map = ws.connection_streams.lock().await;
6883                        map.get("s1").unwrap().id.clone()
6884                    };
6885                    ws.clone()
6886                        .subscribe(vec!["s1".into(), "s2".into()], None, None)
6887                        .await;
6888                    let map = ws.connection_streams.lock().await;
6889                    let second_id = map.get("s1").unwrap().id.clone();
6890                    assert_eq!(first_id, second_id);
6891                    assert!(map.contains_key("s2"));
6892                });
6893            }
6894
6895            #[test]
6896            fn queue_when_not_ready_with_url_path() {
6897                TOKIO_SHARED_RT.block_on(async {
6898                    let ws = create_websocket_streams(None, None, None);
6899
6900                    let conn = ws.common.connection_pool[0].clone();
6901                    {
6902                        let mut st = conn.state.lock().await;
6903                        st.ws_write_tx = None;
6904                        st.url_path = Some("path1".to_string());
6905                        st.reconnection_pending = false;
6906                        st.close_initiated = false;
6907                    }
6908
6909                    ws.clone()
6910                        .subscribe(vec!["s1".into()], None, Some("path1"))
6911                        .await;
6912
6913                    let state = conn.state.lock().await;
6914                    let pending: Vec<String> =
6915                        state.pending_subscriptions.iter().cloned().collect();
6916                    assert_eq!(pending, vec!["s1".to_string()]);
6917                });
6918            }
6919
6920            #[test]
6921            fn only_one_subscription_per_stream_per_url_path() {
6922                TOKIO_SHARED_RT.block_on(async {
6923                    let ws = create_websocket_streams(None, None, None);
6924
6925                    let conn = ws.common.connection_pool[0].clone();
6926                    {
6927                        let mut st = conn.state.lock().await;
6928                        st.ws_write_tx = None;
6929                        st.url_path = Some("path1".to_string());
6930                        st.reconnection_pending = false;
6931                        st.close_initiated = false;
6932                    }
6933
6934                    ws.clone()
6935                        .subscribe(vec!["s1".into()], None, Some("path1"))
6936                        .await;
6937                    ws.clone()
6938                        .subscribe(vec!["s1".into()], None, Some("path1"))
6939                        .await;
6940
6941                    let state = conn.state.lock().await;
6942                    let pending: Vec<String> =
6943                        state.pending_subscriptions.iter().cloned().collect();
6944                    assert_eq!(pending, vec!["s1".to_string()]);
6945                });
6946            }
6947
6948            #[test]
6949            fn same_stream_can_be_subscribed_on_different_url_paths() {
6950                TOKIO_SHARED_RT.block_on(async {
6951                    let ws = create_websocket_streams(None, None, None);
6952
6953                    let conn1 = ws.common.connection_pool[0].clone();
6954                    let conn2 = ws.common.connection_pool[1].clone();
6955
6956                    {
6957                        let mut st1 = conn1.state.lock().await;
6958                        st1.ws_write_tx = None;
6959                        st1.url_path = Some("path1".to_string());
6960                        st1.reconnection_pending = false;
6961                        st1.close_initiated = false;
6962                    }
6963                    {
6964                        let mut st2 = conn2.state.lock().await;
6965                        st2.ws_write_tx = None;
6966                        st2.url_path = Some("path2".to_string());
6967                        st2.reconnection_pending = false;
6968                        st2.close_initiated = false;
6969                    }
6970
6971                    ws.clone()
6972                        .subscribe(vec!["s1".into()], None, Some("path1"))
6973                        .await;
6974                    ws.clone()
6975                        .subscribe(vec!["s1".into()], None, Some("path2"))
6976                        .await;
6977
6978                    let map = ws.connection_streams.lock().await;
6979                    assert!(map.contains_key("path1::s1"));
6980                    assert!(map.contains_key("path2::s1"));
6981                });
6982            }
6983        }
6984
6985        mod unsubscribe {
6986            use super::*;
6987
6988            #[test]
6989            fn removes_stream_with_no_callbacks() {
6990                TOKIO_SHARED_RT.block_on(async {
6991                    let ws = create_websocket_streams(None, None, None);
6992                    let conn = ws.common.connection_pool[0].clone();
6993
6994                    {
6995                        let (tx, _rx) = unbounded_channel::<Message>();
6996                        let mut st = conn.state.lock().await;
6997                        st.ws_write_tx = Some(tx);
6998                    }
6999
7000                    {
7001                        let mut map = ws.connection_streams.lock().await;
7002                        map.insert("s1".to_string(), conn.clone());
7003                    }
7004                    {
7005                        let mut st = conn.state.lock().await;
7006                        st.stream_callbacks.insert("s1".to_string(), Vec::new());
7007                    }
7008
7009                    ws.unsubscribe(vec!["s1".to_string()], None, None).await;
7010
7011                    assert!(!ws.connection_streams.lock().await.contains_key("s1"));
7012                    assert!(!conn.state.lock().await.stream_callbacks.contains_key("s1"));
7013                });
7014            }
7015
7016            #[test]
7017            fn preserves_stream_with_callbacks() {
7018                TOKIO_SHARED_RT.block_on(async {
7019                    let ws = create_websocket_streams(None, None, None);
7020                    let conn = ws.common.connection_pool[1].clone();
7021
7022                    {
7023                        let mut map = ws.connection_streams.lock().await;
7024                        map.insert("s2".to_string(), conn.clone());
7025                    }
7026                    {
7027                        let mut state = conn.state.lock().await;
7028                        state
7029                            .stream_callbacks
7030                            .insert("s2".to_string(), vec![Arc::new(|_: &Value| {})]);
7031                    }
7032
7033                    ws.unsubscribe(vec!["s2".to_string()], None, None).await;
7034
7035                    assert!(ws.connection_streams.lock().await.contains_key("s2"));
7036                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s2"));
7037                });
7038            }
7039
7040            #[test]
7041            fn does_not_send_if_callbacks_exist() {
7042                TOKIO_SHARED_RT.block_on(async {
7043                    let ws = create_websocket_streams(None, None, None);
7044                    let conn = ws.common.connection_pool[0].clone();
7045                    {
7046                        let mut map = ws.connection_streams.lock().await;
7047                        map.insert("s1".to_string(), conn.clone());
7048                    }
7049                    {
7050                        let mut state = conn.state.lock().await;
7051                        state.stream_callbacks.insert(
7052                            "s1".to_string(),
7053                            vec![Arc::new(|_: &Value| {}), Arc::new(|_: &Value| {})],
7054                        );
7055                    }
7056                    ws.unsubscribe(vec!["s1".into()], None, None).await;
7057                    assert!(ws.connection_streams.lock().await.contains_key("s1"));
7058                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s1"));
7059                });
7060            }
7061
7062            #[test]
7063            fn warns_if_not_associated() {
7064                TOKIO_SHARED_RT.block_on(async {
7065                    let ws = create_websocket_streams(None, None, None);
7066                    ws.unsubscribe(vec!["nope".into()], None, None).await;
7067                });
7068            }
7069
7070            #[test]
7071            fn empty_list_does_nothing() {
7072                TOKIO_SHARED_RT.block_on(async {
7073                    let ws = create_websocket_streams(None, None, None);
7074                    let before = ws.connection_streams.lock().await.len();
7075                    ws.unsubscribe(Vec::<String>::new(), None, None).await;
7076                    let after = ws.connection_streams.lock().await.len();
7077                    assert_eq!(before, after);
7078                });
7079            }
7080
7081            #[test]
7082            fn invalid_custom_id_falls_back() {
7083                TOKIO_SHARED_RT.block_on(async {
7084                    let ws = create_websocket_streams(None, None, None);
7085                    let conn = ws.common.connection_pool[0].clone();
7086                    {
7087                        let mut map = ws.connection_streams.lock().await;
7088                        map.insert("foo".to_string(), conn.clone());
7089                    }
7090                    {
7091                        let mut state = conn.state.lock().await;
7092                        let (tx, _rx) = unbounded_channel();
7093                        state.ws_write_tx = Some(tx);
7094                        state.stream_callbacks.insert("foo".to_string(), Vec::new());
7095                    }
7096                    ws.unsubscribe(
7097                        vec!["foo".into()],
7098                        Some(StreamId::Str("bad-id".into())),
7099                        None,
7100                    )
7101                    .await;
7102                    assert!(!ws.connection_streams.lock().await.contains_key("foo"));
7103                });
7104            }
7105
7106            #[test]
7107            fn removes_even_without_write_channel() {
7108                TOKIO_SHARED_RT.block_on(async {
7109                    let ws = create_websocket_streams(None, None, None);
7110                    let conn = ws.common.connection_pool[0].clone();
7111                    {
7112                        let mut map = ws.connection_streams.lock().await;
7113                        map.insert("x".to_string(), conn.clone());
7114                    }
7115                    {
7116                        let mut state = conn.state.lock().await;
7117                        let (tx, _rx) = unbounded_channel();
7118                        state.ws_write_tx = Some(tx);
7119                        state.stream_callbacks.insert("x".to_string(), Vec::new());
7120                    }
7121                    ws.unsubscribe(vec!["x".into()], None, None).await;
7122                    assert!(!ws.connection_streams.lock().await.contains_key("x"));
7123                });
7124            }
7125
7126            #[test]
7127            fn removes_stream_with_no_callbacks_with_url_path() {
7128                TOKIO_SHARED_RT.block_on(async {
7129                    let ws = create_websocket_streams(None, None, None);
7130                    let conn = ws.common.connection_pool[0].clone();
7131
7132                    {
7133                        let (tx, _rx) = unbounded_channel::<Message>();
7134                        let mut st = conn.state.lock().await;
7135                        st.ws_write_tx = Some(tx);
7136                        st.url_path = Some("path1".to_string());
7137                    }
7138
7139                    {
7140                        let mut map = ws.connection_streams.lock().await;
7141                        map.insert("path1::s1".to_string(), conn.clone());
7142                    }
7143                    {
7144                        let mut st = conn.state.lock().await;
7145                        st.stream_callbacks
7146                            .insert("path1::s1".to_string(), Vec::new());
7147                    }
7148
7149                    ws.unsubscribe(vec!["s1".to_string()], None, Some("path1"))
7150                        .await;
7151
7152                    assert!(!ws.connection_streams.lock().await.contains_key("path1::s1"));
7153                    assert!(
7154                        !conn
7155                            .state
7156                            .lock()
7157                            .await
7158                            .stream_callbacks
7159                            .contains_key("path1::s1")
7160                    );
7161                });
7162            }
7163
7164            #[test]
7165            fn preserves_stream_with_callbacks_with_url_path() {
7166                TOKIO_SHARED_RT.block_on(async {
7167                    let ws = create_websocket_streams(None, None, None);
7168                    let conn = ws.common.connection_pool[0].clone();
7169
7170                    {
7171                        let (tx, _rx) = unbounded_channel::<Message>();
7172                        let mut st = conn.state.lock().await;
7173                        st.ws_write_tx = Some(tx);
7174                        st.url_path = Some("path1".to_string());
7175                    }
7176
7177                    {
7178                        let mut map = ws.connection_streams.lock().await;
7179                        map.insert("path1::s2".to_string(), conn.clone());
7180                    }
7181                    {
7182                        let mut state = conn.state.lock().await;
7183                        state
7184                            .stream_callbacks
7185                            .insert("path1::s2".to_string(), vec![Arc::new(|_: &Value| {})]);
7186                    }
7187
7188                    ws.unsubscribe(vec!["s2".to_string()], None, Some("path1"))
7189                        .await;
7190
7191                    assert!(ws.connection_streams.lock().await.contains_key("path1::s2"));
7192                    assert!(
7193                        conn.state
7194                            .lock()
7195                            .await
7196                            .stream_callbacks
7197                            .contains_key("path1::s2")
7198                    );
7199                });
7200            }
7201
7202            #[test]
7203            fn url_path_mismatch_does_not_remove_other_path_subscription() {
7204                TOKIO_SHARED_RT.block_on(async {
7205                    let ws = create_websocket_streams(None, None, None);
7206                    let conn = ws.common.connection_pool[0].clone();
7207
7208                    {
7209                        let (tx, _rx) = unbounded_channel::<Message>();
7210                        let mut st = conn.state.lock().await;
7211                        st.ws_write_tx = Some(tx);
7212                        st.url_path = Some("path1".to_string());
7213                    }
7214
7215                    {
7216                        let mut map = ws.connection_streams.lock().await;
7217                        map.insert("path1::s1".to_string(), conn.clone());
7218                    }
7219                    {
7220                        let mut st = conn.state.lock().await;
7221                        st.stream_callbacks
7222                            .insert("path1::s1".to_string(), Vec::new());
7223                    }
7224
7225                    ws.unsubscribe(vec!["s1".to_string()], None, Some("path2"))
7226                        .await;
7227
7228                    assert!(ws.connection_streams.lock().await.contains_key("path1::s1"));
7229                    assert!(
7230                        conn.state
7231                            .lock()
7232                            .await
7233                            .stream_callbacks
7234                            .contains_key("path1::s1")
7235                    );
7236                });
7237            }
7238        }
7239
7240        mod is_subscribed {
7241            use super::*;
7242
7243            #[test]
7244            fn returns_false_when_not_subscribed() {
7245                TOKIO_SHARED_RT.block_on(async {
7246                    let ws = create_websocket_streams(None, None, None);
7247                    assert!(!ws.is_subscribed("unknown").await);
7248                });
7249            }
7250
7251            #[test]
7252            fn returns_true_when_subscribed() {
7253                TOKIO_SHARED_RT.block_on(async {
7254                    let ws = create_websocket_streams(None, None, None);
7255                    let conn = ws.common.connection_pool[0].clone();
7256                    {
7257                        let mut map = ws.connection_streams.lock().await;
7258                        map.insert("stream1".to_string(), conn);
7259                    }
7260                    assert!(ws.is_subscribed("stream1").await);
7261                });
7262            }
7263
7264            #[test]
7265            fn returns_true_when_subscribed_with_url_path_key() {
7266                TOKIO_SHARED_RT.block_on(async {
7267                    let ws = create_websocket_streams(None, None, None);
7268                    let conn = ws.common.connection_pool[0].clone();
7269                    {
7270                        let mut map = ws.connection_streams.lock().await;
7271                        map.insert("path1::stream1".to_string(), conn);
7272                    }
7273                    assert!(ws.is_subscribed("stream1").await);
7274                });
7275            }
7276
7277            #[test]
7278            fn returns_true_when_same_stream_subscribed_on_multiple_paths() {
7279                TOKIO_SHARED_RT.block_on(async {
7280                    let ws = create_websocket_streams(None, None, None);
7281                    let conn1 = ws.common.connection_pool[0].clone();
7282                    let conn2 = ws.common.connection_pool[1].clone();
7283                    {
7284                        let mut map = ws.connection_streams.lock().await;
7285                        map.insert("path1::stream1".to_string(), conn1);
7286                        map.insert("path2::stream1".to_string(), conn2);
7287                    }
7288                    assert!(ws.is_subscribed("stream1").await);
7289                });
7290            }
7291
7292            #[test]
7293            fn returns_false_when_only_similar_suffix_exists() {
7294                TOKIO_SHARED_RT.block_on(async {
7295                    let ws = create_websocket_streams(None, None, None);
7296                    let conn = ws.common.connection_pool[0].clone();
7297                    {
7298                        let mut map = ws.connection_streams.lock().await;
7299                        map.insert("path1::stream10".to_string(), conn);
7300                    }
7301                    assert!(!ws.is_subscribed("stream1").await);
7302                });
7303            }
7304        }
7305
7306        mod stream_key {
7307            use super::*;
7308
7309            #[test]
7310            fn stream_key_without_url_path_returns_stream() {
7311                TOKIO_SHARED_RT.block_on(async {
7312                    let ws = create_websocket_streams(None, None, None);
7313                    assert_eq!(ws.stream_key("s1", None), "s1");
7314                });
7315            }
7316
7317            #[test]
7318            fn stream_key_with_empty_url_path_returns_stream() {
7319                TOKIO_SHARED_RT.block_on(async {
7320                    let ws = create_websocket_streams(None, None, None);
7321                    assert_eq!(ws.stream_key("s1", Some("")), "s1");
7322                });
7323            }
7324
7325            #[test]
7326            fn stream_key_with_url_path_prefixes_stream() {
7327                TOKIO_SHARED_RT.block_on(async {
7328                    let ws = create_websocket_streams(None, None, None);
7329                    assert_eq!(ws.stream_key("s1", Some("path1")), "path1::s1");
7330                });
7331            }
7332
7333            #[test]
7334            fn stream_key_distinguishes_paths() {
7335                TOKIO_SHARED_RT.block_on(async {
7336                    let ws = create_websocket_streams(None, None, None);
7337                    assert_eq!(ws.stream_key("s1", Some("path1")), "path1::s1");
7338                    assert_eq!(ws.stream_key("s1", Some("path2")), "path2::s1");
7339                });
7340            }
7341        }
7342
7343        mod prepare_url {
7344            use super::*;
7345
7346            #[test]
7347            fn without_time_unit_returns_base_url() {
7348                TOKIO_SHARED_RT.block_on(async {
7349                    let conns = vec![
7350                        WebsocketConnection::new("c1"),
7351                        WebsocketConnection::new("c2"),
7352                    ];
7353                    let config = ConfigurationWebsocketStreams {
7354                        ws_url: Some("wss://example".to_string()),
7355                        mode: WebsocketMode::Single,
7356                        reconnect_delay: 100,
7357                        time_unit: None,
7358                        agent: None,
7359                        user_agent: build_user_agent("product"),
7360                    };
7361                    let ws = WebsocketStreams::new(config, conns, vec![]);
7362                    let url = ws.prepare_url(&["s1".into(), "s2".into()], None);
7363                    assert_eq!(url, "wss://example/stream?streams=s1/s2");
7364                });
7365            }
7366
7367            #[test]
7368            fn with_time_unit_appends_parameter() {
7369                TOKIO_SHARED_RT.block_on(async {
7370                    let conns = vec![WebsocketConnection::new("c1")];
7371                    let config = ConfigurationWebsocketStreams {
7372                        ws_url: Some("wss://example".to_string()),
7373                        mode: WebsocketMode::Single,
7374                        reconnect_delay: 100,
7375                        time_unit: Some(TimeUnit::Millisecond),
7376                        agent: None,
7377                        user_agent: build_user_agent("product"),
7378                    };
7379                    let ws = WebsocketStreams::new(config, conns, vec![]);
7380                    let url = ws.prepare_url(&["a".into()], None);
7381                    assert_eq!(url, "wss://example/stream?streams=a&timeUnit=millisecond");
7382                });
7383            }
7384
7385            #[test]
7386            fn multiple_streams_and_time_unit() {
7387                TOKIO_SHARED_RT.block_on(async {
7388                    let conns = vec![WebsocketConnection::new("c1")];
7389                    let config = ConfigurationWebsocketStreams {
7390                        ws_url: Some("wss://example".to_string()),
7391                        mode: WebsocketMode::Single,
7392                        reconnect_delay: 100,
7393                        time_unit: Some(TimeUnit::Microsecond),
7394                        agent: None,
7395                        user_agent: build_user_agent("product"),
7396                    };
7397                    let ws = WebsocketStreams::new(config, conns, vec![]);
7398                    let url = ws.prepare_url(&["x".into(), "y".into(), "z".into()], None);
7399                    assert_eq!(
7400                        url,
7401                        "wss://example/stream?streams=x/y/z&timeUnit=microsecond"
7402                    );
7403                });
7404            }
7405
7406            #[test]
7407            fn with_url_path_prefixes_base_url() {
7408                TOKIO_SHARED_RT.block_on(async {
7409                    let conns = vec![WebsocketConnection::new("c1")];
7410                    let config = ConfigurationWebsocketStreams {
7411                        ws_url: Some("wss://example".to_string()),
7412                        mode: WebsocketMode::Single,
7413                        reconnect_delay: 100,
7414                        time_unit: None,
7415                        agent: None,
7416                        user_agent: build_user_agent("product"),
7417                    };
7418                    let ws = WebsocketStreams::new(config, conns, vec![]);
7419                    let url = ws.prepare_url(["s1".into()].as_ref(), Some("path1"));
7420                    assert_eq!(url, "wss://example/path1/stream?streams=s1");
7421                });
7422            }
7423
7424            #[test]
7425            fn with_url_path_and_time_unit_appends_parameter() {
7426                TOKIO_SHARED_RT.block_on(async {
7427                    let conns = vec![WebsocketConnection::new("c1")];
7428                    let config = ConfigurationWebsocketStreams {
7429                        ws_url: Some("wss://example".to_string()),
7430                        mode: WebsocketMode::Single,
7431                        reconnect_delay: 100,
7432                        time_unit: Some(TimeUnit::Millisecond),
7433                        agent: None,
7434                        user_agent: build_user_agent("product"),
7435                    };
7436                    let ws = WebsocketStreams::new(config, conns, vec![]);
7437                    let url = ws.prepare_url(["a".into()].as_ref(), Some("path1"));
7438                    assert_eq!(
7439                        url,
7440                        "wss://example/path1/stream?streams=a&timeUnit=millisecond"
7441                    );
7442                });
7443            }
7444
7445            #[test]
7446            fn url_path_distinguishes_urls_for_same_streams() {
7447                TOKIO_SHARED_RT.block_on(async {
7448                    let conns = vec![WebsocketConnection::new("c1")];
7449                    let config = ConfigurationWebsocketStreams {
7450                        ws_url: Some("wss://example".to_string()),
7451                        mode: WebsocketMode::Single,
7452                        reconnect_delay: 100,
7453                        time_unit: None,
7454                        agent: None,
7455                        user_agent: build_user_agent("product"),
7456                    };
7457                    let ws = WebsocketStreams::new(config, conns, vec![]);
7458                    let u1 = ws.prepare_url(["s1".into()].as_ref(), Some("path1"));
7459                    let u2 = ws.prepare_url(["s1".into()].as_ref(), Some("path2"));
7460                    assert_eq!(u1, "wss://example/path1/stream?streams=s1");
7461                    assert_eq!(u2, "wss://example/path2/stream?streams=s1");
7462                });
7463            }
7464        }
7465
7466        mod handle_stream_assignment {
7467            use super::*;
7468
7469            #[test]
7470            fn assigns_new_streams_to_connections() {
7471                TOKIO_SHARED_RT.block_on(async {
7472                    let ws = create_websocket_streams(None, None, None);
7473                    let groups = ws
7474                        .clone()
7475                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7476                        .await;
7477                    let mut seen_streams = HashSet::new();
7478                    for (_conn, streams) in &groups {
7479                        for s in streams {
7480                            seen_streams.insert(s);
7481                        }
7482                    }
7483                    assert_eq!(
7484                        seen_streams,
7485                        ["s1".to_string(), "s2".to_string()].iter().collect()
7486                    );
7487                    assert_eq!(groups.len(), 1);
7488                });
7489            }
7490
7491            #[test]
7492            fn reuses_existing_connection_for_duplicate_stream() {
7493                TOKIO_SHARED_RT.block_on(async {
7494                    let ws = create_websocket_streams(None, None, None);
7495                    let _ = ws
7496                        .clone()
7497                        .handle_stream_assignment(vec!["s1".into()], None)
7498                        .await;
7499                    let groups = ws
7500                        .clone()
7501                        .handle_stream_assignment(vec!["s1".into(), "s3".into()], None)
7502                        .await;
7503                    let mut all_streams = Vec::new();
7504                    for (_conn, streams) in groups {
7505                        all_streams.extend(streams);
7506                    }
7507                    all_streams.sort();
7508                    assert_eq!(all_streams, vec!["s1".to_string(), "s3".to_string()]);
7509                });
7510            }
7511
7512            #[test]
7513            fn empty_stream_list_returns_empty() {
7514                TOKIO_SHARED_RT.block_on(async {
7515                    let ws = create_websocket_streams(None, None, None);
7516                    let groups = ws.clone().handle_stream_assignment(vec![], None).await;
7517                    assert!(groups.is_empty());
7518                });
7519            }
7520
7521            #[test]
7522            fn closed_or_reconnecting_forces_reassignment_of_stream() {
7523                TOKIO_SHARED_RT.block_on(async {
7524                    let ws = create_websocket_streams(None, None, None);
7525                    let mut groups = ws
7526                        .clone()
7527                        .handle_stream_assignment(vec!["s1".into()], None)
7528                        .await;
7529                    let (conn, _) = groups.pop().unwrap();
7530                    {
7531                        let mut st = conn.state.lock().await;
7532                        st.close_initiated = true;
7533                    }
7534                    let groups2 = ws
7535                        .clone()
7536                        .handle_stream_assignment(vec!["s2".into()], None)
7537                        .await;
7538                    assert_eq!(groups2.len(), 1);
7539                    let (_new_conn, streams) = &groups2[0];
7540                    assert_eq!(streams, &vec!["s2".to_string()]);
7541                });
7542            }
7543
7544            #[test]
7545            fn no_available_connections_falls_back_to_one() {
7546                TOKIO_SHARED_RT.block_on(async {
7547                    let ws = create_websocket_streams(None, Some(vec![]), None);
7548                    let assigned = ws.handle_stream_assignment(vec!["foo".into()], None).await;
7549                    assert_eq!(assigned.len(), 1);
7550                    let (_conn, streams) = &assigned[0];
7551                    assert_eq!(streams.as_slice(), &["foo".to_string()]);
7552                });
7553            }
7554
7555            #[test]
7556            fn single_connection_groups_multiple_streams() {
7557                TOKIO_SHARED_RT.block_on(async {
7558                    let conn = WebsocketConnection::new("c1");
7559                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7560                    let assigned = ws
7561                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7562                        .await;
7563                    assert_eq!(assigned.len(), 1);
7564                    let (assigned_conn, streams) = &assigned[0];
7565                    assert!(Arc::ptr_eq(assigned_conn, &conn));
7566                    assert_eq!(streams.len(), 2);
7567                    assert!(streams.contains(&"s1".to_string()));
7568                    assert!(streams.contains(&"s2".to_string()));
7569                });
7570            }
7571
7572            #[test]
7573            fn reuse_existing_healthy_connection() {
7574                TOKIO_SHARED_RT.block_on(async {
7575                    let conn = WebsocketConnection::new("c");
7576                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7577                    let _ = ws.handle_stream_assignment(vec!["s1".into()], None).await;
7578                    let second = ws.handle_stream_assignment(vec!["s1".into()], None).await;
7579                    assert_eq!(second.len(), 1);
7580                    let (assigned_conn, streams) = &second[0];
7581                    assert!(Arc::ptr_eq(assigned_conn, &conn));
7582                    assert_eq!(streams.as_slice(), &["s1".to_string()]);
7583                });
7584            }
7585
7586            #[test]
7587            fn mix_new_and_assigned_streams() {
7588                TOKIO_SHARED_RT.block_on(async {
7589                    let conn = WebsocketConnection::new("c");
7590                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]), None);
7591                    let _ = ws
7592                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], None)
7593                        .await;
7594                    let mixed = ws
7595                        .handle_stream_assignment(vec!["s2".into(), "s3".into()], None)
7596                        .await;
7597                    assert_eq!(mixed.len(), 1);
7598                    let (assigned_conn, streams) = &mixed[0];
7599                    assert!(Arc::ptr_eq(assigned_conn, &conn));
7600                    let mut got = streams.clone();
7601                    got.sort();
7602                    assert_eq!(got, vec!["s2".to_string(), "s3".to_string()]);
7603                });
7604            }
7605
7606            #[test]
7607            fn assigns_streams_with_url_path_keys() {
7608                TOKIO_SHARED_RT.block_on(async {
7609                    let ws = create_websocket_streams(None, None, None);
7610
7611                    let conn = ws.common.connection_pool[0].clone();
7612                    {
7613                        let mut st = conn.state.lock().await;
7614                        st.url_path = Some("path1".to_string());
7615                        st.ws_write_tx = None;
7616                        st.reconnection_pending = false;
7617                        st.close_initiated = false;
7618                    }
7619
7620                    let groups = ws
7621                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], Some("path1"))
7622                        .await;
7623
7624                    let map = ws.connection_streams.lock().await;
7625                    assert!(map.contains_key("path1::s1"));
7626                    assert!(map.contains_key("path1::s2"));
7627                    assert_eq!(groups.len(), 1);
7628
7629                    let (_assigned_conn, streams) = &groups[0];
7630                    let mut got = streams.clone();
7631                    got.sort();
7632                    assert_eq!(got, vec!["s1".to_string(), "s2".to_string()]);
7633                });
7634            }
7635
7636            #[test]
7637            fn same_stream_on_different_paths_creates_distinct_keys() {
7638                TOKIO_SHARED_RT.block_on(async {
7639                    let ws = create_websocket_streams(None, None, None);
7640
7641                    let conn1 = ws.common.connection_pool[0].clone();
7642                    let conn2 = ws.common.connection_pool[1].clone();
7643
7644                    {
7645                        let mut st = conn1.state.lock().await;
7646                        st.url_path = Some("path1".to_string());
7647                        st.ws_write_tx = None;
7648                        st.reconnection_pending = false;
7649                        st.close_initiated = false;
7650                    }
7651                    {
7652                        let mut st = conn2.state.lock().await;
7653                        st.url_path = Some("path2".to_string());
7654                        st.ws_write_tx = None;
7655                        st.reconnection_pending = false;
7656                        st.close_initiated = false;
7657                    }
7658
7659                    let g1 = ws
7660                        .handle_stream_assignment(vec!["s1".into()], Some("path1"))
7661                        .await;
7662                    let g2 = ws
7663                        .handle_stream_assignment(vec!["s1".into()], Some("path2"))
7664                        .await;
7665
7666                    assert_eq!(g1.len(), 1);
7667                    assert_eq!(g2.len(), 1);
7668
7669                    let map = ws.connection_streams.lock().await;
7670                    assert!(map.contains_key("path1::s1"));
7671                    assert!(map.contains_key("path2::s1"));
7672                });
7673            }
7674
7675            #[test]
7676            fn reuses_existing_connection_for_same_path_and_stream() {
7677                TOKIO_SHARED_RT.block_on(async {
7678                    let ws = create_websocket_streams(None, None, None);
7679
7680                    let conn = ws.common.connection_pool[0].clone();
7681                    {
7682                        let mut st = conn.state.lock().await;
7683                        st.url_path = Some("path1".to_string());
7684                        st.ws_write_tx = None;
7685                        st.reconnection_pending = false;
7686                        st.close_initiated = false;
7687                    }
7688
7689                    let first = ws
7690                        .handle_stream_assignment(vec!["s1".into()], Some("path1"))
7691                        .await;
7692                    let second = ws
7693                        .handle_stream_assignment(vec!["s1".into(), "s2".into()], Some("path1"))
7694                        .await;
7695
7696                    assert_eq!(first.len(), 1);
7697                    assert_eq!(second.len(), 1);
7698
7699                    let map = ws.connection_streams.lock().await;
7700                    let c1 = map.get("path1::s1").unwrap().clone();
7701                    let c2 = map.get("path1::s2").unwrap().clone();
7702                    assert!(Arc::ptr_eq(&c1, &c2));
7703                });
7704            }
7705
7706            #[test]
7707            fn closed_or_reconnecting_forces_reassignment_with_url_path() {
7708                TOKIO_SHARED_RT.block_on(async {
7709                    let ws = create_websocket_streams(None, None, None);
7710
7711                    let conn1 = ws.common.connection_pool[0].clone();
7712                    let conn2 = ws.common.connection_pool[1].clone();
7713
7714                    {
7715                        let mut st = conn1.state.lock().await;
7716                        st.url_path = Some("path1".to_string());
7717                        st.ws_write_tx = None;
7718                        st.reconnection_pending = false;
7719                        st.close_initiated = false;
7720                    }
7721                    {
7722                        let mut st = conn2.state.lock().await;
7723                        st.url_path = Some("path1".to_string());
7724                        st.ws_write_tx = None;
7725                        st.reconnection_pending = false;
7726                        st.close_initiated = false;
7727                    }
7728
7729                    let _ = ws
7730                        .handle_stream_assignment(vec!["s1".into()], Some("path1"))
7731                        .await;
7732
7733                    {
7734                        let mut st = conn1.state.lock().await;
7735                        st.close_initiated = true;
7736                    }
7737
7738                    let _ = ws
7739                        .handle_stream_assignment(vec!["s1".into()], Some("path1"))
7740                        .await;
7741
7742                    let map = ws.connection_streams.lock().await;
7743                    let assigned = map.get("path1::s1").unwrap().clone();
7744                    assert!(!Arc::ptr_eq(&assigned, &conn1));
7745                });
7746            }
7747        }
7748
7749        mod send_subscription_payload {
7750            use super::*;
7751
7752            #[test]
7753            fn subscribe_payload_with_custom_id_fallbacks_if_invalid() {
7754                TOKIO_SHARED_RT.block_on(async {
7755                    let ws: Arc<WebsocketStreams> =
7756                        create_websocket_streams(Some("ws://example.com"), None, None);
7757                    let conn = &ws.common.connection_pool[0];
7758                    let (tx, mut rx) = unbounded_channel();
7759                    {
7760                        let mut st = conn.state.lock().await;
7761                        st.ws_write_tx = Some(tx);
7762                    }
7763                    let id = Some("badid".to_string());
7764                    ws.send_subscription_payload(
7765                        conn,
7766                        &vec!["s1".to_string()],
7767                        id.map(StreamId::from),
7768                    );
7769                    let msg = rx.recv().await.expect("no message sent");
7770                    if let Message::Text(txt) = msg {
7771                        let v: serde_json::Value = serde_json::from_str(&txt).unwrap();
7772                        assert_eq!(v["method"], "SUBSCRIBE");
7773                        let id = v["id"].as_str().unwrap();
7774                        assert_ne!(id, "badid");
7775                        assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id));
7776                    } else {
7777                        panic!("unexpected message: {msg:?}");
7778                    }
7779                });
7780            }
7781
7782            #[test]
7783            fn subscribe_payload_with_and_without_custom_string_id() {
7784                TOKIO_SHARED_RT.block_on(async {
7785                    let ws: Arc<WebsocketStreams> =
7786                        create_websocket_streams(Some("ws://unused"), None, None);
7787                    let conn = &ws.common.connection_pool[0];
7788                    let (tx, mut rx) = unbounded_channel();
7789                    {
7790                        let mut st = conn.state.lock().await;
7791                        st.ws_write_tx = Some(tx);
7792                    }
7793                    let id = Some("deadbeefdeadbeefdeadbeefdeadbeef".to_string());
7794                    ws.send_subscription_payload(
7795                        conn,
7796                        &vec!["a".to_string(), "b".to_string()],
7797                        id.map(StreamId::from),
7798                    );
7799                    let msg1 = rx.recv().await.unwrap();
7800                    ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
7801                    let msg2 = rx.recv().await.unwrap();
7802
7803                    if let Message::Text(txt1) = msg1 {
7804                        let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
7805                        assert_eq!(v1["id"], "deadbeefdeadbeefdeadbeefdeadbeef");
7806                        assert_eq!(
7807                            v1["params"].as_array().unwrap(),
7808                            &vec![serde_json::json!("a"), serde_json::json!("b")]
7809                        );
7810                    } else {
7811                        panic!()
7812                    }
7813
7814                    if let Message::Text(txt2) = msg2 {
7815                        let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
7816                        assert_eq!(v2["method"], "SUBSCRIBE");
7817                        let params = v2["params"].as_array().unwrap();
7818                        assert_eq!(params.len(), 1);
7819                        assert_eq!(params[0], "x");
7820                        let id2 = v2["id"].as_str().unwrap();
7821                        assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id2));
7822                    } else {
7823                        panic!()
7824                    }
7825                });
7826            }
7827
7828            #[test]
7829            fn subscribe_payload_with_and_without_custom_integer_id() {
7830                TOKIO_SHARED_RT.block_on(async {
7831                    let ws: Arc<WebsocketStreams> =
7832                        create_websocket_streams(Some("ws://unused"), None, None);
7833                    ws.stream_id_is_strictly_number
7834                        .store(true, Ordering::Relaxed);
7835                    let conn = &ws.common.connection_pool[0];
7836                    let (tx, mut rx) = unbounded_channel();
7837                    {
7838                        let mut st = conn.state.lock().await;
7839                        st.ws_write_tx = Some(tx);
7840                    }
7841
7842                    let id = Some(123u32);
7843
7844                    ws.send_subscription_payload(
7845                        conn,
7846                        &vec!["a".to_string(), "b".to_string()],
7847                        id.map(StreamId::from),
7848                    );
7849                    let msg1 = rx.recv().await.unwrap();
7850
7851                    ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
7852                    let msg2 = rx.recv().await.unwrap();
7853
7854                    if let Message::Text(txt1) = msg1 {
7855                        let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
7856                        assert_eq!(v1["method"], "SUBSCRIBE");
7857                        assert_eq!(v1["id"].as_u64(), Some(123));
7858                        assert_eq!(
7859                            v1["params"].as_array().unwrap(),
7860                            &vec![serde_json::json!("a"), serde_json::json!("b")]
7861                        );
7862                    } else {
7863                        panic!("Expected Message::Text for msg1");
7864                    }
7865
7866                    if let Message::Text(txt2) = msg2 {
7867                        let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
7868                        assert_eq!(v2["method"], "SUBSCRIBE");
7869
7870                        let params = v2["params"].as_array().unwrap();
7871                        assert_eq!(params.len(), 1);
7872                        assert_eq!(params[0], "x");
7873
7874                        let id2 = v2.get("id").expect("payload should contain id");
7875                        assert!(
7876                            id2.is_number(),
7877                            "expected numeric id in strict-number mode, got: {id2:?}"
7878                        );
7879                        let n = id2.as_u64().unwrap();
7880                        assert!(u32::try_from(n).is_ok(), "id should fit u32, got {n}");
7881                    } else {
7882                        panic!("Expected Message::Text for msg2");
7883                    }
7884                });
7885            }
7886        }
7887
7888        mod on_open {
7889            use super::*;
7890
7891            #[test]
7892            fn sends_pending_subscriptions() {
7893                TOKIO_SHARED_RT.block_on(async {
7894                    let ws: Arc<WebsocketStreams> =
7895                        create_websocket_streams(Some("ws://example.com"), None, None);
7896                    let conn = &ws.common.connection_pool[0];
7897                    let (tx, mut rx) = unbounded_channel();
7898                    {
7899                        let mut st = conn.state.lock().await;
7900                        st.ws_write_tx = Some(tx);
7901                        st.pending_subscriptions.push_back("foo".to_string());
7902                        st.pending_subscriptions.push_back("bar".to_string());
7903                    }
7904                    ws.on_open("ws://example.com".to_string(), conn.clone())
7905                        .await;
7906                    let msg = rx.recv().await.expect("no subscription sent");
7907                    if let Message::Text(txt) = msg {
7908                        let v: Value = serde_json::from_str(&txt).unwrap();
7909                        assert_eq!(v["method"], "SUBSCRIBE");
7910                        let params = v["params"].as_array().unwrap();
7911                        assert_eq!(
7912                            params,
7913                            &vec![Value::String("foo".into()), Value::String("bar".into())]
7914                        );
7915                    } else {
7916                        panic!("unexpected message: {msg:?}");
7917                    }
7918                    let st_after = conn.state.lock().await;
7919                    assert!(st_after.pending_subscriptions.is_empty());
7920                });
7921            }
7922
7923            #[test]
7924            fn with_no_pending_subscriptions_sends_nothing() {
7925                TOKIO_SHARED_RT.block_on(async {
7926                    let ws: Arc<WebsocketStreams> =
7927                        create_websocket_streams(Some("ws://example.com"), None, None);
7928                    let conn = &ws.common.connection_pool[0];
7929                    let (tx, mut rx) = unbounded_channel();
7930                    {
7931                        let mut st = conn.state.lock().await;
7932                        st.ws_write_tx = Some(tx);
7933                    }
7934                    ws.on_open("ws://example.com".to_string(), conn.clone())
7935                        .await;
7936                    assert!(rx.try_recv().is_err(), "unexpected message sent");
7937                });
7938            }
7939
7940            #[test]
7941            fn clears_pending_without_write_channel() {
7942                TOKIO_SHARED_RT.block_on(async {
7943                    let ws: Arc<WebsocketStreams> =
7944                        create_websocket_streams(Some("ws://example.com"), None, None);
7945                    let conn = &ws.common.connection_pool[0];
7946                    {
7947                        let mut st = conn.state.lock().await;
7948                        st.pending_subscriptions.push_back("solo".to_string());
7949                    }
7950                    ws.on_open("ws://example.com".to_string(), conn.clone())
7951                        .await;
7952                    let st_after = conn.state.lock().await;
7953                    assert!(st_after.pending_subscriptions.is_empty());
7954                });
7955            }
7956        }
7957
7958        mod on_message {
7959            use super::*;
7960
7961            #[test]
7962            fn invokes_registered_callback() {
7963                TOKIO_SHARED_RT.block_on(async {
7964                    let ws: Arc<WebsocketStreams> =
7965                        create_websocket_streams(Some("ws://example.com"), None, None);
7966                    let conn = &ws.common.connection_pool[0];
7967                    let called = Arc::new(AtomicBool::new(false));
7968                    let called_clone = called.clone();
7969
7970                    {
7971                        let mut st = conn.state.lock().await;
7972                        st.stream_callbacks
7973                            .entry("stream1".to_string())
7974                            .or_default()
7975                            .push(
7976                                (Box::new(move |_: &Value| {
7977                                    called_clone.store(true, Ordering::SeqCst);
7978                                })
7979                                    as Box<dyn Fn(&Value) + Send + Sync>)
7980                                    .into(),
7981                            );
7982                    }
7983
7984                    let msg = json!({
7985                        "stream": "stream1",
7986                        "data": { "key": "value" }
7987                    })
7988                    .to_string();
7989
7990                    ws.on_message(msg, conn.clone()).await;
7991
7992                    assert!(called.load(Ordering::SeqCst));
7993                });
7994            }
7995
7996            #[test]
7997            fn invokes_all_registered_callbacks() {
7998                TOKIO_SHARED_RT.block_on(async {
7999                    let ws: Arc<WebsocketStreams> =
8000                        create_websocket_streams(Some("ws://example.com"), None, None);
8001                    let conn = &ws.common.connection_pool[0];
8002                    let counter = Arc::new(AtomicUsize::new(0));
8003
8004                    {
8005                        let mut st = conn.state.lock().await;
8006                        let entry = st.stream_callbacks.entry("s".into()).or_default();
8007                        let c1 = counter.clone();
8008                        entry.push(
8009                            (Box::new(move |_: &Value| {
8010                                c1.fetch_add(1, Ordering::SeqCst);
8011                            }) as Box<dyn Fn(&Value) + Send + Sync>)
8012                                .into(),
8013                        );
8014                        let c2 = counter.clone();
8015                        entry.push(
8016                            (Box::new(move |_: &Value| {
8017                                c2.fetch_add(1, Ordering::SeqCst);
8018                            }) as Box<dyn Fn(&Value) + Send + Sync>)
8019                                .into(),
8020                        );
8021                    }
8022
8023                    let msg = json!({"stream":"s","data":42}).to_string();
8024                    ws.on_message(msg, conn.clone()).await;
8025
8026                    assert_eq!(counter.load(Ordering::SeqCst), 2);
8027                });
8028            }
8029
8030            #[test]
8031            fn handles_null_data_field() {
8032                TOKIO_SHARED_RT.block_on(async {
8033                    let ws: Arc<WebsocketStreams> =
8034                        create_websocket_streams(Some("ws://example.com"), None, None);
8035                    let conn = &ws.common.connection_pool[0];
8036                    let called = Arc::new(AtomicUsize::new(0));
8037                    {
8038                        let mut st = conn.state.lock().await;
8039                        st.stream_callbacks.entry("n".into()).or_default().push(
8040                            (Box::new({
8041                                let c = called.clone();
8042                                move |data: &Value| {
8043                                    if data.is_null() {
8044                                        c.fetch_add(1, Ordering::SeqCst);
8045                                    }
8046                                }
8047                            }) as Box<dyn Fn(&Value) + Send + Sync>)
8048                                .into(),
8049                        );
8050                    }
8051                    let msg = json!({"stream":"n","data":null}).to_string();
8052                    ws.on_message(msg, conn.clone()).await;
8053                    assert_eq!(called.load(Ordering::SeqCst), 1);
8054                });
8055            }
8056
8057            #[test]
8058            fn with_invalid_json_does_not_panic() {
8059                TOKIO_SHARED_RT.block_on(async {
8060                    let ws: Arc<WebsocketStreams> =
8061                        create_websocket_streams(Some("ws://example.com"), None, None);
8062                    let conn = &ws.common.connection_pool[0];
8063                    let bad = "not a json";
8064                    ws.on_message(bad.to_string(), conn.clone()).await;
8065                });
8066            }
8067
8068            #[test]
8069            fn without_stream_field_does_nothing() {
8070                TOKIO_SHARED_RT.block_on(async {
8071                    let ws: Arc<WebsocketStreams> =
8072                        create_websocket_streams(Some("ws://example.com"), None, None);
8073                    let conn = &ws.common.connection_pool[0];
8074                    let msg = json!({ "data": { "foo": 1 } }).to_string();
8075                    ws.on_message(msg, conn.clone()).await;
8076                });
8077            }
8078
8079            #[test]
8080            fn with_unregistered_stream_does_not_panic() {
8081                TOKIO_SHARED_RT.block_on(async {
8082                    let ws: Arc<WebsocketStreams> =
8083                        create_websocket_streams(Some("ws://example.com"), None, None);
8084                    let conn = &ws.common.connection_pool[0];
8085                    let msg = json!({
8086                        "stream": "nope",
8087                        "data": { "foo": 1 }
8088                    })
8089                    .to_string();
8090                    ws.on_message(msg, conn.clone()).await;
8091                });
8092            }
8093
8094            #[test]
8095            fn invokes_registered_callback_with_url_path_key() {
8096                TOKIO_SHARED_RT.block_on(async {
8097                    let ws: Arc<WebsocketStreams> =
8098                        create_websocket_streams(Some("ws://example.com"), None, None);
8099                    let conn = &ws.common.connection_pool[0];
8100
8101                    {
8102                        let mut st = conn.state.lock().await;
8103                        st.url_path = Some("path1".to_string());
8104                    }
8105
8106                    let called = Arc::new(AtomicBool::new(false));
8107                    let called_clone = called.clone();
8108
8109                    {
8110                        let mut st = conn.state.lock().await;
8111                        st.stream_callbacks
8112                            .entry("path1::stream1".to_string())
8113                            .or_default()
8114                            .push(
8115                                (Box::new(move |_: &Value| {
8116                                    called_clone.store(true, Ordering::SeqCst);
8117                                })
8118                                    as Box<dyn Fn(&Value) + Send + Sync>)
8119                                    .into(),
8120                            );
8121                    }
8122
8123                    let msg = json!({
8124                        "stream": "stream1",
8125                        "data": { "key": "value" }
8126                    })
8127                    .to_string();
8128
8129                    ws.on_message(msg, conn.clone()).await;
8130
8131                    assert!(called.load(Ordering::SeqCst));
8132                });
8133            }
8134
8135            #[test]
8136            fn does_not_invoke_callback_when_url_path_mismatch() {
8137                TOKIO_SHARED_RT.block_on(async {
8138                    let ws: Arc<WebsocketStreams> =
8139                        create_websocket_streams(Some("ws://example.com"), None, None);
8140                    let conn = &ws.common.connection_pool[0];
8141
8142                    {
8143                        let mut st = conn.state.lock().await;
8144                        st.url_path = Some("path2".to_string());
8145                    }
8146
8147                    let called = Arc::new(AtomicBool::new(false));
8148                    let called_clone = called.clone();
8149
8150                    {
8151                        let mut st = conn.state.lock().await;
8152                        st.stream_callbacks
8153                            .entry("path1::stream1".to_string())
8154                            .or_default()
8155                            .push(
8156                                (Box::new(move |_: &Value| {
8157                                    called_clone.store(true, Ordering::SeqCst);
8158                                })
8159                                    as Box<dyn Fn(&Value) + Send + Sync>)
8160                                    .into(),
8161                            );
8162                    }
8163
8164                    let msg = json!({
8165                        "stream": "stream1",
8166                        "data": { "key": "value" }
8167                    })
8168                    .to_string();
8169
8170                    ws.on_message(msg, conn.clone()).await;
8171
8172                    assert!(!called.load(Ordering::SeqCst));
8173                });
8174            }
8175
8176            #[test]
8177            fn invokes_only_callbacks_for_current_url_path_when_both_exist() {
8178                TOKIO_SHARED_RT.block_on(async {
8179                    let ws: Arc<WebsocketStreams> =
8180                        create_websocket_streams(Some("ws://example.com"), None, None);
8181                    let conn = &ws.common.connection_pool[0];
8182
8183                    {
8184                        let mut st = conn.state.lock().await;
8185                        st.url_path = Some("path1".to_string());
8186                    }
8187
8188                    let c1 = Arc::new(AtomicUsize::new(0));
8189                    let c2 = Arc::new(AtomicUsize::new(0));
8190
8191                    {
8192                        let mut st = conn.state.lock().await;
8193
8194                        let a = c1.clone();
8195                        st.stream_callbacks
8196                            .entry("path1::s".to_string())
8197                            .or_default()
8198                            .push(
8199                                (Box::new(move |_: &Value| {
8200                                    a.fetch_add(1, Ordering::SeqCst);
8201                                })
8202                                    as Box<dyn Fn(&Value) + Send + Sync>)
8203                                    .into(),
8204                            );
8205
8206                        let b = c2.clone();
8207                        st.stream_callbacks
8208                            .entry("path2::s".to_string())
8209                            .or_default()
8210                            .push(
8211                                (Box::new(move |_: &Value| {
8212                                    b.fetch_add(1, Ordering::SeqCst);
8213                                })
8214                                    as Box<dyn Fn(&Value) + Send + Sync>)
8215                                    .into(),
8216                            );
8217                    }
8218
8219                    let msg = json!({"stream":"s","data":42}).to_string();
8220                    ws.on_message(msg, conn.clone()).await;
8221
8222                    assert_eq!(c1.load(Ordering::SeqCst), 1);
8223                    assert_eq!(c2.load(Ordering::SeqCst), 0);
8224                });
8225            }
8226        }
8227
8228        mod get_reconnect_url {
8229            use super::*;
8230
8231            #[test]
8232            fn single_stream_reconnect_url() {
8233                TOKIO_SHARED_RT.block_on(async {
8234                    let ws: Arc<WebsocketStreams> =
8235                        create_websocket_streams(Some("ws://example.com"), None, None);
8236                    let c0 = ws.common.connection_pool[0].clone();
8237                    {
8238                        let mut map = ws.connection_streams.lock().await;
8239                        map.insert("s1".to_string(), c0.clone());
8240                    }
8241                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8242                    assert_eq!(url, "ws://example.com/stream?streams=s1");
8243                });
8244            }
8245
8246            #[test]
8247            fn multiple_streams_same_connection() {
8248                TOKIO_SHARED_RT.block_on(async {
8249                    let ws: Arc<WebsocketStreams> =
8250                        create_websocket_streams(Some("ws://example.com"), None, None);
8251                    let c0 = ws.common.connection_pool[0].clone();
8252                    {
8253                        let mut map = ws.connection_streams.lock().await;
8254                        map.insert("a".to_string(), c0.clone());
8255                        map.insert("b".to_string(), c0.clone());
8256                    }
8257                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8258                    let suffix = url
8259                        .strip_prefix("ws://example.com/stream?streams=")
8260                        .unwrap();
8261                    let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
8262                    let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
8263                    assert_eq!(set, ["a", "b"].iter().copied().collect());
8264                });
8265            }
8266
8267            #[test]
8268            fn reconnect_url_with_time_unit() {
8269                TOKIO_SHARED_RT.block_on(async {
8270                    let mut ws: Arc<WebsocketStreams> =
8271                        create_websocket_streams(Some("ws://example.com"), None, None);
8272                    Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
8273                        Some(TimeUnit::Microsecond);
8274                    let c0 = ws.common.connection_pool[0].clone();
8275                    {
8276                        let mut map = ws.connection_streams.lock().await;
8277                        map.insert("x".to_string(), c0.clone());
8278                    }
8279                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8280                    assert_eq!(
8281                        url,
8282                        "ws://example.com/stream?streams=x&timeUnit=microsecond"
8283                    );
8284                });
8285            }
8286
8287            #[test]
8288            fn reconnect_url_uses_url_path_from_connection_state() {
8289                TOKIO_SHARED_RT.block_on(async {
8290                    let ws: Arc<WebsocketStreams> =
8291                        create_websocket_streams(Some("ws://example.com"), None, None);
8292                    let c0 = ws.common.connection_pool[0].clone();
8293
8294                    {
8295                        let mut st = c0.state.lock().await;
8296                        st.url_path = Some("path1".to_string());
8297                    }
8298
8299                    {
8300                        let mut map = ws.connection_streams.lock().await;
8301                        map.insert("path1::s1".to_string(), c0.clone());
8302                    }
8303
8304                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8305                    assert_eq!(url, "ws://example.com/path1/stream?streams=s1");
8306                });
8307            }
8308
8309            #[test]
8310            fn reconnect_url_strips_prefix_from_multiple_keys_with_url_path() {
8311                TOKIO_SHARED_RT.block_on(async {
8312                    let ws: Arc<WebsocketStreams> =
8313                        create_websocket_streams(Some("ws://example.com"), None, None);
8314                    let c0 = ws.common.connection_pool[0].clone();
8315
8316                    {
8317                        let mut st = c0.state.lock().await;
8318                        st.url_path = Some("path1".to_string());
8319                    }
8320
8321                    {
8322                        let mut map = ws.connection_streams.lock().await;
8323                        map.insert("path1::a".to_string(), c0.clone());
8324                        map.insert("path1::b".to_string(), c0.clone());
8325                    }
8326
8327                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8328
8329                    let suffix = url
8330                        .strip_prefix("ws://example.com/path1/stream?streams=")
8331                        .unwrap();
8332                    let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
8333                    let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
8334                    assert_eq!(set, ["a", "b"].iter().copied().collect());
8335                });
8336            }
8337
8338            #[test]
8339            fn reconnect_url_with_url_path_and_time_unit() {
8340                TOKIO_SHARED_RT.block_on(async {
8341                    let mut ws: Arc<WebsocketStreams> =
8342                        create_websocket_streams(Some("ws://example.com"), None, None);
8343                    Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
8344                        Some(TimeUnit::Microsecond);
8345
8346                    let c0 = ws.common.connection_pool[0].clone();
8347
8348                    {
8349                        let mut st = c0.state.lock().await;
8350                        st.url_path = Some("path1".to_string());
8351                    }
8352
8353                    {
8354                        let mut map = ws.connection_streams.lock().await;
8355                        map.insert("path1::x".to_string(), c0.clone());
8356                    }
8357
8358                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8359                    assert_eq!(
8360                        url,
8361                        "ws://example.com/path1/stream?streams=x&timeUnit=microsecond"
8362                    );
8363                });
8364            }
8365
8366            #[test]
8367            fn reconnect_url_ignores_streams_from_other_connections_even_if_same_path_prefix() {
8368                TOKIO_SHARED_RT.block_on(async {
8369                    let ws: Arc<WebsocketStreams> =
8370                        create_websocket_streams(Some("ws://example.com"), None, None);
8371                    let c0 = ws.common.connection_pool[0].clone();
8372                    let c1 = ws.common.connection_pool[1].clone();
8373
8374                    {
8375                        let mut st = c0.state.lock().await;
8376                        st.url_path = Some("path1".to_string());
8377                    }
8378                    {
8379                        let mut st = c1.state.lock().await;
8380                        st.url_path = Some("path1".to_string());
8381                    }
8382
8383                    {
8384                        let mut map = ws.connection_streams.lock().await;
8385                        map.insert("path1::a".to_string(), c0.clone());
8386                        map.insert("path1::b".to_string(), c1.clone());
8387                    }
8388
8389                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
8390
8391                    let suffix = url
8392                        .strip_prefix("ws://example.com/path1/stream?streams=")
8393                        .unwrap();
8394                    let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
8395                    let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
8396                    assert_eq!(set, ["a"].iter().copied().collect());
8397                });
8398            }
8399        }
8400    }
8401
8402    mod websocket_stream {
8403        use super::*;
8404
8405        mod on {
8406            use super::*;
8407
8408            #[test]
8409            fn registers_callback_and_stream_callback_for_websocket_streams() {
8410                TOKIO_SHARED_RT.block_on(async {
8411                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8412                    let stream_name = "s1".to_string();
8413                    let conn = ws_base.common.connection_pool[0].clone();
8414
8415                    let key = ws_base.stream_key(&stream_name, None);
8416
8417                    {
8418                        let mut map = ws_base.connection_streams.lock().await;
8419                        map.insert(key.clone(), conn.clone());
8420                    }
8421                    {
8422                        let mut state = conn.state.lock().await;
8423                        state.stream_callbacks.insert(key.clone(), Vec::new());
8424                    }
8425
8426                    let stream = Arc::new(WebsocketStream::<Value> {
8427                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8428                        stream_or_id: stream_name.clone(),
8429                        callback: Mutex::new(None),
8430                        url_path: None,
8431                        id: None,
8432                        _phantom: PhantomData,
8433                    });
8434
8435                    stream.on("message", |_| {}).await;
8436
8437                    let cb_guard = stream.callback.lock().await;
8438                    assert!(cb_guard.is_some());
8439
8440                    let cbs = {
8441                        let state = conn.state.lock().await;
8442                        state.stream_callbacks.get(&key).unwrap().clone()
8443                    };
8444                    assert_eq!(cbs.len(), 1);
8445                });
8446            }
8447
8448            #[test]
8449            fn message_twice_registers_two_wrappers_for_websocket_streams() {
8450                TOKIO_SHARED_RT.block_on(async {
8451                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8452                    let stream_name = "s2".to_string();
8453                    let conn = ws_base.common.connection_pool[0].clone();
8454
8455                    let key = ws_base.stream_key(&stream_name, None);
8456
8457                    {
8458                        let mut map = ws_base.connection_streams.lock().await;
8459                        map.insert(key.clone(), conn.clone());
8460                    }
8461                    {
8462                        let mut state = conn.state.lock().await;
8463                        state.stream_callbacks.insert(key.clone(), Vec::new());
8464                    }
8465
8466                    let stream = Arc::new(WebsocketStream::<Value> {
8467                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8468                        stream_or_id: stream_name.clone(),
8469                        url_path: None,
8470                        callback: Mutex::new(None),
8471                        id: None,
8472                        _phantom: PhantomData,
8473                    });
8474
8475                    stream.on("message", |_| {}).await;
8476                    stream.on("message", |_| {}).await;
8477
8478                    let state = conn.state.lock().await;
8479                    let callbacks = state.stream_callbacks.get(&key).unwrap();
8480                    assert_eq!(callbacks.len(), 2);
8481                });
8482            }
8483
8484            #[test]
8485            fn ignores_non_message_event_for_websocket_streams() {
8486                TOKIO_SHARED_RT.block_on(async {
8487                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8488                    let stream = Arc::new(WebsocketStream::<Value> {
8489                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8490                        stream_or_id: "s".into(),
8491                        url_path: None,
8492                        callback: Mutex::new(None),
8493                        id: None,
8494                        _phantom: PhantomData,
8495                    });
8496                    stream.on("open", |_| {}).await;
8497                    let guard = stream.callback.lock().await;
8498                    assert!(guard.is_none());
8499                });
8500            }
8501
8502            #[test]
8503            fn registers_callback_and_stream_callback_for_websocket_api() {
8504                TOKIO_SHARED_RT.block_on(async {
8505                    let ws_base = create_websocket_api(None, None, None);
8506
8507                    {
8508                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
8509                        stream_callbacks.insert("id1".to_string(), Vec::new());
8510                    }
8511
8512                    let stream = Arc::new(WebsocketStream::<Value> {
8513                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8514                        stream_or_id: "id1".to_string(),
8515                        url_path: None,
8516                        callback: Mutex::new(None),
8517                        id: None,
8518                        _phantom: PhantomData,
8519                    });
8520
8521                    let called = Arc::new(Mutex::new(false));
8522                    let called_clone = called.clone();
8523                    stream
8524                        .on("message", move |v: Value| {
8525                            let mut lock = called_clone.blocking_lock();
8526                            *lock = v == Value::String("x".into());
8527                        })
8528                        .await;
8529
8530                    let cb_guard = stream.callback.lock().await;
8531                    assert!(cb_guard.is_some());
8532
8533                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8534                    let callbacks = stream_callbacks.get("id1").unwrap();
8535                    assert_eq!(callbacks.len(), 1);
8536                });
8537            }
8538
8539            #[test]
8540            fn message_twice_registers_two_wrappers_for_websocket_api() {
8541                TOKIO_SHARED_RT.block_on(async {
8542                    let ws_base = create_websocket_api(None, None, None);
8543
8544                    let stream = Arc::new(WebsocketStream::<Value> {
8545                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8546                        stream_or_id: "id2".to_string(),
8547                        url_path: None,
8548                        callback: Mutex::new(None),
8549                        id: None,
8550                        _phantom: PhantomData,
8551                    });
8552
8553                    stream.on("message", |_| {}).await;
8554                    stream.on("message", |_| {}).await;
8555
8556                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8557                    let callbacks = stream_callbacks.get("id2").unwrap();
8558                    assert_eq!(callbacks.len(), 2);
8559                });
8560            }
8561
8562            #[test]
8563            fn ignores_non_message_event_for_websocket_api() {
8564                TOKIO_SHARED_RT.block_on(async {
8565                    let ws_base = create_websocket_api(None, None, None);
8566
8567                    let stream = Arc::new(WebsocketStream::<Value> {
8568                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8569                        stream_or_id: "id3".into(),
8570                        url_path: None,
8571                        callback: Mutex::new(None),
8572                        id: None,
8573                        _phantom: PhantomData,
8574                    });
8575
8576                    stream.on("open", |_| {}).await;
8577
8578                    let guard = stream.callback.lock().await;
8579                    assert!(guard.is_none());
8580
8581                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8582                    assert!(stream_callbacks.get("id3").is_none());
8583                    assert!(stream_callbacks.is_empty());
8584                });
8585            }
8586
8587            #[test]
8588            fn registers_callback_for_websocket_streams_with_url_path() {
8589                TOKIO_SHARED_RT.block_on(async {
8590                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8591                    let stream_name = "s1".to_string();
8592                    let conn = ws_base.common.connection_pool[0].clone();
8593
8594                    let key = ws_base.stream_key(&stream_name, Some("path1"));
8595
8596                    {
8597                        let mut map = ws_base.connection_streams.lock().await;
8598                        map.insert(key.clone(), conn.clone());
8599                    }
8600                    {
8601                        let mut state = conn.state.lock().await;
8602                        state.stream_callbacks.insert(key.clone(), Vec::new());
8603                    }
8604
8605                    let stream = Arc::new(WebsocketStream::<Value> {
8606                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8607                        stream_or_id: stream_name.clone(),
8608                        url_path: Some("path1".to_string()),
8609                        callback: Mutex::new(None),
8610                        id: None,
8611                        _phantom: PhantomData,
8612                    });
8613
8614                    stream.on("message", |_| {}).await;
8615
8616                    let cb_guard = stream.callback.lock().await;
8617                    assert!(cb_guard.is_some());
8618
8619                    let callbacks = {
8620                        let state = conn.state.lock().await;
8621                        state.stream_callbacks.get(&key).unwrap().clone()
8622                    };
8623                    assert_eq!(callbacks.len(), 1);
8624                });
8625            }
8626
8627            #[test]
8628            fn url_path_routes_callback_to_correct_key_when_same_stream_name_used() {
8629                TOKIO_SHARED_RT.block_on(async {
8630                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8631                    let stream_name = "s1".to_string();
8632                    let conn = ws_base.common.connection_pool[0].clone();
8633
8634                    let key1 = ws_base.stream_key(&stream_name, Some("path1"));
8635                    let key2 = ws_base.stream_key(&stream_name, Some("path2"));
8636
8637                    {
8638                        let mut map = ws_base.connection_streams.lock().await;
8639                        map.insert(key1.clone(), conn.clone());
8640                        map.insert(key2.clone(), conn.clone());
8641                    }
8642                    {
8643                        let mut state = conn.state.lock().await;
8644                        state.stream_callbacks.insert(key1.clone(), Vec::new());
8645                        state.stream_callbacks.insert(key2.clone(), Vec::new());
8646                    }
8647
8648                    let stream_path1 = Arc::new(WebsocketStream::<Value> {
8649                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8650                        stream_or_id: stream_name.clone(),
8651                        url_path: Some("path1".to_string()),
8652                        callback: Mutex::new(None),
8653                        id: None,
8654                        _phantom: PhantomData,
8655                    });
8656
8657                    let stream_path2 = Arc::new(WebsocketStream::<Value> {
8658                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8659                        stream_or_id: stream_name.clone(),
8660                        url_path: Some("path2".to_string()),
8661                        callback: Mutex::new(None),
8662                        id: None,
8663                        _phantom: PhantomData,
8664                    });
8665
8666                    stream_path1.on("message", |_| {}).await;
8667                    stream_path2.on("message", |_| {}).await;
8668
8669                    let state = conn.state.lock().await;
8670                    assert_eq!(state.stream_callbacks.get(&key1).unwrap().len(), 1);
8671                    assert_eq!(state.stream_callbacks.get(&key2).unwrap().len(), 1);
8672                });
8673            }
8674        }
8675
8676        mod on_message {
8677            use super::*;
8678
8679            #[test]
8680            fn on_message_registers_callback_for_websocket_streams() {
8681                TOKIO_SHARED_RT.block_on(async {
8682                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8683                    let stream_name = "s".to_string();
8684                    let conn = ws_base.common.connection_pool[0].clone();
8685                    {
8686                        let mut map = ws_base.connection_streams.lock().await;
8687                        map.insert(stream_name.clone(), conn.clone());
8688                    }
8689                    {
8690                        let mut state = conn.state.lock().await;
8691                        state
8692                            .stream_callbacks
8693                            .insert(stream_name.clone(), Vec::new());
8694                    }
8695                    let stream = Arc::new(WebsocketStream::<Value> {
8696                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8697                        stream_or_id: stream_name.clone(),
8698                        url_path: None,
8699                        callback: Mutex::new(None),
8700                        id: None,
8701                        _phantom: PhantomData,
8702                    });
8703                    stream.on_message(|_v| {});
8704                    let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
8705                    assert_eq!(callbacks.len(), 1);
8706                });
8707            }
8708
8709            #[test]
8710            fn on_message_twice_registers_two_callbacks_for_websocket_streams() {
8711                TOKIO_SHARED_RT.block_on(async {
8712                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8713                    let stream_name = "s".to_string();
8714                    let conn = ws_base.common.connection_pool[0].clone();
8715                    {
8716                        let mut map = ws_base.connection_streams.lock().await;
8717                        map.insert(stream_name.clone(), conn.clone());
8718                    }
8719                    {
8720                        let mut state = conn.state.lock().await;
8721                        state
8722                            .stream_callbacks
8723                            .insert(stream_name.clone(), Vec::new());
8724                    }
8725                    let stream = Arc::new(WebsocketStream::<Value> {
8726                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8727                        stream_or_id: stream_name.clone(),
8728                        url_path: None,
8729                        callback: Mutex::new(None),
8730                        id: None,
8731                        _phantom: PhantomData,
8732                    });
8733                    stream.on_message(|_v| {});
8734                    stream.on_message(|_v| {});
8735                    let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
8736                    assert_eq!(callbacks.len(), 2);
8737                });
8738            }
8739
8740            #[test]
8741            fn on_message_registers_callback_for_websocket_api() {
8742                TOKIO_SHARED_RT.block_on(async {
8743                    let ws_base = create_websocket_api(None, None, None);
8744                    let identifier = "id1".to_string();
8745
8746                    let stream = Arc::new(WebsocketStream::<Value> {
8747                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8748                        stream_or_id: identifier.clone(),
8749                        url_path: None,
8750                        callback: Mutex::new(None),
8751                        id: None,
8752                        _phantom: PhantomData,
8753                    });
8754
8755                    stream.on_message(|_v: Value| {});
8756
8757                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8758                    let callbacks = stream_callbacks.get(&identifier).unwrap();
8759                    assert_eq!(callbacks.len(), 1);
8760                });
8761            }
8762
8763            #[test]
8764            fn on_message_twice_registers_two_callbacks_for_websocket_api() {
8765                TOKIO_SHARED_RT.block_on(async {
8766                    let ws_base = create_websocket_api(None, None, None);
8767                    let identifier = "id2".to_string();
8768
8769                    let stream = Arc::new(WebsocketStream::<Value> {
8770                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8771                        stream_or_id: identifier.clone(),
8772                        url_path: None,
8773                        callback: Mutex::new(None),
8774                        id: None,
8775                        _phantom: PhantomData,
8776                    });
8777
8778                    stream.on_message(|_v: Value| {});
8779                    stream.on_message(|_v: Value| {});
8780
8781                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8782                    let callbacks = stream_callbacks.get(&identifier).unwrap();
8783                    assert_eq!(callbacks.len(), 2);
8784                });
8785            }
8786        }
8787
8788        mod unsubscribe {
8789            use super::*;
8790
8791            #[test]
8792            fn without_callback_does_nothing() {
8793                TOKIO_SHARED_RT.block_on(async {
8794                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8795                    let stream_name = "s1".to_string();
8796                    let conn = ws_base.common.connection_pool[0].clone();
8797                    {
8798                        let mut map = ws_base.connection_streams.lock().await;
8799                        map.insert(stream_name.clone(), conn.clone());
8800                    }
8801                    let mut state = conn.state.lock().await;
8802                    state.stream_callbacks.insert(stream_name.clone(), vec![]);
8803                    drop(state);
8804                    let stream = Arc::new(WebsocketStream::<Value> {
8805                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8806                        stream_or_id: stream_name.clone(),
8807                        url_path: None,
8808                        callback: Mutex::new(None),
8809                        id: None,
8810                        _phantom: PhantomData,
8811                    });
8812                    stream.unsubscribe().await;
8813                    let state = conn.state.lock().await;
8814                    assert!(state.stream_callbacks.contains_key(&stream_name));
8815                });
8816            }
8817
8818            #[test]
8819            fn removes_registered_callback_and_clears_state() {
8820                TOKIO_SHARED_RT.block_on(async {
8821                    let ws_base = create_websocket_streams(Some("example.com"), None, None);
8822                    let stream_name = "s2".to_string();
8823                    let conn = ws_base.common.connection_pool[0].clone();
8824                    {
8825                        let mut map = ws_base.connection_streams.lock().await;
8826                        map.insert(stream_name.clone(), conn.clone());
8827                    }
8828                    {
8829                        let mut state = conn.state.lock().await;
8830                        state
8831                            .stream_callbacks
8832                            .insert(stream_name.clone(), Vec::new());
8833                    }
8834                    let stream = Arc::new(WebsocketStream::<Value> {
8835                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
8836                        stream_or_id: stream_name.clone(),
8837                        url_path: None,
8838                        callback: Mutex::new(None),
8839                        id: None,
8840                        _phantom: PhantomData,
8841                    });
8842                    stream.on("message", |_| {}).await;
8843                    {
8844                        let guard = stream.callback.lock().await;
8845                        assert!(guard.is_some());
8846                    }
8847                    stream.unsubscribe().await;
8848                    sleep(Duration::from_millis(10)).await;
8849                    let guard = stream.callback.lock().await;
8850                    assert!(guard.is_none());
8851                    let state = conn.state.lock().await;
8852                    assert!(
8853                        state
8854                            .stream_callbacks
8855                            .get(&stream_name)
8856                            .is_none_or(std::vec::Vec::is_empty)
8857                    );
8858                });
8859            }
8860
8861            #[test]
8862            fn without_callback_does_nothing_for_websocket_api() {
8863                TOKIO_SHARED_RT.block_on(async {
8864                    let ws_base = create_websocket_api(None, None, None);
8865                    let identifier = "id1".to_string();
8866
8867                    {
8868                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
8869                        stream_callbacks.insert(identifier.clone(), Vec::new());
8870                    }
8871
8872                    let stream = Arc::new(WebsocketStream::<Value> {
8873                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8874                        stream_or_id: identifier.clone(),
8875                        url_path: None,
8876                        callback: Mutex::new(None),
8877                        id: None,
8878                        _phantom: PhantomData,
8879                    });
8880
8881                    stream.unsubscribe().await;
8882
8883                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
8884                    assert!(stream_callbacks.contains_key(&identifier));
8885                    let callbacks = stream_callbacks.get(&identifier).unwrap();
8886                    assert!(callbacks.is_empty());
8887                });
8888            }
8889
8890            #[test]
8891            fn removes_registered_callback_and_clears_state_for_websocket_api() {
8892                TOKIO_SHARED_RT.block_on(async {
8893                    let ws_base = create_websocket_api(None, None, None);
8894                    let identifier = "id2".to_string();
8895
8896                    {
8897                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
8898                        stream_callbacks.insert(identifier.clone(), Vec::new());
8899                    }
8900
8901                    let stream = Arc::new(WebsocketStream::<Value> {
8902                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
8903                        stream_or_id: identifier.clone(),
8904                        url_path: None,
8905                        callback: Mutex::new(None),
8906                        id: None,
8907                        _phantom: PhantomData,
8908                    });
8909
8910                    stream.on("message", |_| {}).await;
8911
8912                    {
8913                        let stream_callbacks = ws_base.stream_callbacks.lock().await;
8914                        let callbacks = stream_callbacks
8915                            .get(&identifier)
8916                            .expect("Entry for 'id2' should exist");
8917                        assert_eq!(callbacks.len(), 1);
8918                    }
8919
8920                    stream.unsubscribe().await;
8921
8922                    {
8923                        let guard = stream.callback.lock().await;
8924                        assert!(guard.is_none());
8925                    }
8926
8927                    {
8928                        let stream_callbacks = ws_base.stream_callbacks.lock().await;
8929                        let callbacks = stream_callbacks
8930                            .get(&identifier)
8931                            .expect("Entry for 'id2' should still exist");
8932                        assert!(callbacks.is_empty());
8933                    }
8934                });
8935            }
8936        }
8937    }
8938
8939    mod create_stream_handler {
8940        use super::*;
8941
8942        #[test]
8943        fn create_stream_handler_without_id_registers_stream() {
8944            TOKIO_SHARED_RT.block_on(async {
8945                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
8946                let stream_name = "foo".to_string();
8947                let handler = create_stream_handler::<serde_json::Value>(
8948                    WebsocketBase::WebsocketStreams(ws.clone()),
8949                    stream_name.clone(),
8950                    None,
8951                    None,
8952                )
8953                .await;
8954                assert_eq!(handler.stream_or_id, stream_name);
8955                assert!(handler.id.is_none());
8956                let map = ws.connection_streams.lock().await;
8957                assert!(map.contains_key(&stream_name));
8958            });
8959        }
8960
8961        #[test]
8962        fn create_stream_handler_with_custom_string_id_registers_stream_and_id() {
8963            TOKIO_SHARED_RT.block_on(async {
8964                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
8965                let stream_name = "bar".to_string();
8966                let custom_id = StreamId::from("my-custom-id".to_string());
8967                let handler = create_stream_handler::<serde_json::Value>(
8968                    WebsocketBase::WebsocketStreams(ws.clone()),
8969                    stream_name.clone(),
8970                    Some(custom_id.clone()),
8971                    None,
8972                )
8973                .await;
8974                assert_eq!(handler.stream_or_id, stream_name);
8975                assert_eq!(handler.id, Some(custom_id));
8976                let map = ws.connection_streams.lock().await;
8977                assert!(map.contains_key(&stream_name));
8978            });
8979        }
8980
8981        #[test]
8982        fn create_stream_handler_with_custom_integer_id_registers_stream_and_id() {
8983            TOKIO_SHARED_RT.block_on(async {
8984                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
8985                let stream_name = "bar".to_string();
8986                let custom_id = StreamId::from(123u32);
8987                let handler = create_stream_handler::<serde_json::Value>(
8988                    WebsocketBase::WebsocketStreams(ws.clone()),
8989                    stream_name.clone(),
8990                    Some(custom_id.clone()),
8991                    None,
8992                )
8993                .await;
8994                assert_eq!(handler.stream_or_id, stream_name);
8995                assert_eq!(handler.id, Some(custom_id));
8996                let map = ws.connection_streams.lock().await;
8997                assert!(map.contains_key(&stream_name));
8998            });
8999        }
9000
9001        #[test]
9002        fn create_stream_handler_without_id_registers_api_stream() {
9003            TOKIO_SHARED_RT.block_on(async {
9004                let ws_base = create_websocket_api(None, None, None);
9005                let identifier = "foo-api".to_string();
9006
9007                let handler = create_stream_handler::<Value>(
9008                    WebsocketBase::WebsocketApi(ws_base.clone()),
9009                    identifier.clone(),
9010                    None,
9011                    None,
9012                )
9013                .await;
9014
9015                assert_eq!(handler.stream_or_id, identifier);
9016                assert!(handler.id.is_none());
9017            });
9018        }
9019
9020        #[test]
9021        fn create_stream_handler_with_custom_string_id_registers_api_stream_and_id() {
9022            TOKIO_SHARED_RT.block_on(async {
9023                let ws_base = create_websocket_api(None, None, None);
9024                let identifier = "bar-api".to_string();
9025                let custom_id = StreamId::from("custom-123".to_string());
9026
9027                let handler = create_stream_handler::<Value>(
9028                    WebsocketBase::WebsocketApi(ws_base.clone()),
9029                    identifier.clone(),
9030                    Some(custom_id.clone()),
9031                    None,
9032                )
9033                .await;
9034
9035                assert_eq!(handler.stream_or_id, identifier);
9036                assert_eq!(handler.id, Some(custom_id));
9037            });
9038        }
9039
9040        #[test]
9041        fn create_stream_handler_with_custom_integer_id_registers_api_stream_and_id() {
9042            TOKIO_SHARED_RT.block_on(async {
9043                let ws_base = create_websocket_api(None, None, None);
9044                let identifier = "bar-api".to_string();
9045                let custom_id = StreamId::from(123u32);
9046
9047                let handler = create_stream_handler::<Value>(
9048                    WebsocketBase::WebsocketApi(ws_base.clone()),
9049                    identifier.clone(),
9050                    Some(custom_id.clone()),
9051                    None,
9052                )
9053                .await;
9054
9055                assert_eq!(handler.stream_or_id, identifier);
9056                assert_eq!(handler.id, Some(custom_id));
9057            });
9058        }
9059
9060        #[test]
9061        fn websocket_streams_without_url_path_registers_stream_key() {
9062            TOKIO_SHARED_RT.block_on(async {
9063                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9064                let stream_name = "foo".to_string();
9065
9066                let handler = create_stream_handler::<Value>(
9067                    WebsocketBase::WebsocketStreams(ws.clone()),
9068                    stream_name.clone(),
9069                    None,
9070                    None,
9071                )
9072                .await;
9073
9074                assert_eq!(handler.stream_or_id, stream_name);
9075                assert!(handler.id.is_none());
9076
9077                let map = ws.connection_streams.lock().await;
9078                assert!(map.contains_key("foo"));
9079            });
9080        }
9081
9082        #[test]
9083        fn websocket_streams_with_url_path_registers_prefixed_stream_key() {
9084            TOKIO_SHARED_RT.block_on(async {
9085                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9086
9087                {
9088                    let conn = ws.common.connection_pool[0].clone();
9089                    let mut st = conn.state.lock().await;
9090                    st.url_path = Some("path1".to_string());
9091                }
9092
9093                let stream_name = "foo".to_string();
9094
9095                let handler = create_stream_handler::<Value>(
9096                    WebsocketBase::WebsocketStreams(ws.clone()),
9097                    stream_name.clone(),
9098                    None,
9099                    Some("path1".to_string()),
9100                )
9101                .await;
9102
9103                assert_eq!(handler.stream_or_id, stream_name);
9104                assert!(handler.id.is_none());
9105
9106                let map = ws.connection_streams.lock().await;
9107                assert!(map.contains_key("path1::foo"));
9108            });
9109        }
9110
9111        #[test]
9112        fn websocket_streams_with_custom_id_preserves_id_and_registers_prefixed_key() {
9113            TOKIO_SHARED_RT.block_on(async {
9114                let ws = create_websocket_streams(Some("ws://example.com"), None, None);
9115
9116                {
9117                    let conn = ws.common.connection_pool[0].clone();
9118                    let mut st = conn.state.lock().await;
9119                    st.url_path = Some("path1".to_string());
9120                }
9121
9122                let stream_name = "bar".to_string();
9123                let custom_id = StreamId::from("my-custom-id".to_string());
9124
9125                let handler = create_stream_handler::<Value>(
9126                    WebsocketBase::WebsocketStreams(ws.clone()),
9127                    stream_name.clone(),
9128                    Some(custom_id.clone()),
9129                    Some("path1".to_string()),
9130                )
9131                .await;
9132
9133                assert_eq!(handler.stream_or_id, stream_name);
9134                assert_eq!(handler.id, Some(custom_id));
9135
9136                let map = ws.connection_streams.lock().await;
9137                assert!(map.contains_key("path1::bar"));
9138            });
9139        }
9140
9141        #[test]
9142        fn websocket_api_does_not_register_stream_in_connection_map() {
9143            TOKIO_SHARED_RT.block_on(async {
9144                let ws_base = create_websocket_api(None, None, None);
9145                let identifier = "foo-api".to_string();
9146
9147                let handler = create_stream_handler::<Value>(
9148                    WebsocketBase::WebsocketApi(ws_base.clone()),
9149                    identifier.clone(),
9150                    None,
9151                    Some("path1".to_string()),
9152                )
9153                .await;
9154
9155                assert_eq!(handler.stream_or_id, identifier);
9156                assert!(handler.id.is_none());
9157            });
9158        }
9159    }
9160
9161    mod websocket_connection_failure_reason {
9162        use super::*;
9163        use std::io::{Error as IoError, ErrorKind};
9164        use tokio_tungstenite::tungstenite::Error as TungsteniteError;
9165
9166        #[test]
9167        fn from_tungstenite_error_classifies_connection_closed() {
9168            let error = TungsteniteError::ConnectionClosed;
9169            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9170            assert!(matches!(
9171                reason,
9172                WebsocketConnectionFailureReason::ConnectionReset
9173            ));
9174            assert!(reason.should_reconnect());
9175        }
9176
9177        #[test]
9178        fn from_tungstenite_error_classifies_already_closed() {
9179            let error = TungsteniteError::AlreadyClosed;
9180            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9181            assert!(matches!(
9182                reason,
9183                WebsocketConnectionFailureReason::ConnectionReset
9184            ));
9185            assert!(reason.should_reconnect());
9186        }
9187
9188        #[test]
9189        fn from_tungstenite_error_classifies_io_errors() {
9190            // Test ConnectionReset
9191            let io_error = IoError::new(ErrorKind::ConnectionReset, "connection reset");
9192            let error = TungsteniteError::Io(io_error);
9193            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9194            assert!(matches!(
9195                reason,
9196                WebsocketConnectionFailureReason::ConnectionReset
9197            ));
9198            assert!(reason.should_reconnect());
9199
9200            // Test ConnectionAborted
9201            let io_error = IoError::new(ErrorKind::ConnectionAborted, "connection aborted");
9202            let error = TungsteniteError::Io(io_error);
9203            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9204            assert!(matches!(
9205                reason,
9206                WebsocketConnectionFailureReason::ConnectionReset
9207            ));
9208            assert!(reason.should_reconnect());
9209
9210            // Test TimedOut
9211            let io_error = IoError::new(ErrorKind::TimedOut, "timed out");
9212            let error = TungsteniteError::Io(io_error);
9213            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9214            assert!(matches!(
9215                reason,
9216                WebsocketConnectionFailureReason::NetworkInterruption
9217            ));
9218            assert!(reason.should_reconnect());
9219
9220            // Test UnexpectedEof
9221            let io_error = IoError::new(ErrorKind::UnexpectedEof, "unexpected eof");
9222            let error = TungsteniteError::Io(io_error);
9223            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9224            assert!(matches!(
9225                reason,
9226                WebsocketConnectionFailureReason::StreamEnded
9227            ));
9228            assert!(reason.should_reconnect());
9229
9230            // Test PermissionDenied
9231            let io_error = IoError::new(ErrorKind::PermissionDenied, "permission denied");
9232            let error = TungsteniteError::Io(io_error);
9233            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9234            assert!(matches!(
9235                reason,
9236                WebsocketConnectionFailureReason::AuthenticationFailure
9237            ));
9238            assert!(!reason.should_reconnect());
9239
9240            // Test other IO errors default to NetworkInterruption
9241            let io_error = IoError::other("other error");
9242            let error = TungsteniteError::Io(io_error);
9243            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9244            assert!(matches!(
9245                reason,
9246                WebsocketConnectionFailureReason::NetworkInterruption
9247            ));
9248            assert!(reason.should_reconnect());
9249        }
9250
9251        #[test]
9252        fn from_tungstenite_error_classifies_protocol_errors() {
9253            // Protocol error -> ProtocolViolation
9254            use tokio_tungstenite::tungstenite::error::ProtocolError;
9255            let protocol_error = ProtocolError::ResetWithoutClosingHandshake;
9256            let error = TungsteniteError::Protocol(protocol_error);
9257            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9258            assert!(matches!(
9259                reason,
9260                WebsocketConnectionFailureReason::ProtocolViolation
9261            ));
9262            assert!(!reason.should_reconnect());
9263
9264            // UTF8 error -> ProtocolViolation
9265            let error = TungsteniteError::Utf8;
9266            let reason = WebsocketConnectionFailureReason::from_tungstenite_error(&error);
9267            assert!(matches!(
9268                reason,
9269                WebsocketConnectionFailureReason::ProtocolViolation
9270            ));
9271            assert!(!reason.should_reconnect());
9272        }
9273
9274        #[test]
9275        fn from_close_code_classifies_standard_codes() {
9276            // Normal closure
9277            let reason = WebsocketConnectionFailureReason::from_close_code(1000, false);
9278            assert!(matches!(
9279                reason,
9280                WebsocketConnectionFailureReason::NormalClose
9281            ));
9282            assert!(!reason.should_reconnect());
9283
9284            // Going away (server restart) -> ServerTemporaryError
9285            let reason = WebsocketConnectionFailureReason::from_close_code(1001, false);
9286            assert!(matches!(
9287                reason,
9288                WebsocketConnectionFailureReason::ServerTemporaryError
9289            ));
9290            assert!(reason.should_reconnect());
9291
9292            // Protocol error -> ProtocolViolation
9293            let reason = WebsocketConnectionFailureReason::from_close_code(1002, false);
9294            assert!(matches!(
9295                reason,
9296                WebsocketConnectionFailureReason::ProtocolViolation
9297            ));
9298            assert!(!reason.should_reconnect());
9299
9300            // Abnormal closure -> UnexpectedClose
9301            let reason = WebsocketConnectionFailureReason::from_close_code(1006, false);
9302            assert!(matches!(
9303                reason,
9304                WebsocketConnectionFailureReason::UnexpectedClose
9305            ));
9306            assert!(reason.should_reconnect());
9307
9308            // Policy violation -> PermanentServerError
9309            let reason = WebsocketConnectionFailureReason::from_close_code(1008, false);
9310            assert!(matches!(
9311                reason,
9312                WebsocketConnectionFailureReason::PermanentServerError
9313            ));
9314            assert!(!reason.should_reconnect());
9315
9316            // Server error -> ServerTemporaryError
9317            let reason = WebsocketConnectionFailureReason::from_close_code(1011, false);
9318            assert!(matches!(
9319                reason,
9320                WebsocketConnectionFailureReason::ServerTemporaryError
9321            ));
9322            assert!(reason.should_reconnect());
9323
9324            // TLS handshake failure -> ConfigurationError
9325            let reason = WebsocketConnectionFailureReason::from_close_code(1015, false);
9326            assert!(matches!(
9327                reason,
9328                WebsocketConnectionFailureReason::ConfigurationError
9329            ));
9330            assert!(!reason.should_reconnect());
9331
9332            // Business/application errors -> PermanentServerError
9333            let reason = WebsocketConnectionFailureReason::from_close_code(4000, false);
9334            assert!(matches!(
9335                reason,
9336                WebsocketConnectionFailureReason::PermanentServerError
9337            ));
9338            assert!(!reason.should_reconnect());
9339
9340            let reason = WebsocketConnectionFailureReason::from_close_code(4999, false);
9341            assert!(matches!(
9342                reason,
9343                WebsocketConnectionFailureReason::PermanentServerError
9344            ));
9345            assert!(!reason.should_reconnect());
9346
9347            // Unknown codes default to UnexpectedClose
9348            let reason = WebsocketConnectionFailureReason::from_close_code(9999, false);
9349            assert!(matches!(
9350                reason,
9351                WebsocketConnectionFailureReason::UnexpectedClose
9352            ));
9353            assert!(reason.should_reconnect());
9354        }
9355
9356        #[test]
9357        fn from_close_code_handles_user_initiated() {
9358            // Any code with user_initiated=true should return UserInitiatedClose
9359            let reason = WebsocketConnectionFailureReason::from_close_code(1000, true);
9360            assert!(matches!(
9361                reason,
9362                WebsocketConnectionFailureReason::UserInitiatedClose
9363            ));
9364            assert!(!reason.should_reconnect());
9365
9366            let reason = WebsocketConnectionFailureReason::from_close_code(1006, true);
9367            assert!(matches!(
9368                reason,
9369                WebsocketConnectionFailureReason::UserInitiatedClose
9370            ));
9371            assert!(!reason.should_reconnect());
9372
9373            let reason = WebsocketConnectionFailureReason::from_close_code(4000, true);
9374            assert!(matches!(
9375                reason,
9376                WebsocketConnectionFailureReason::UserInitiatedClose
9377            ));
9378            assert!(!reason.should_reconnect());
9379        }
9380
9381        #[test]
9382        fn should_reconnect_logic() {
9383            // Reconnectable failures
9384            assert!(WebsocketConnectionFailureReason::NetworkInterruption.should_reconnect());
9385            assert!(WebsocketConnectionFailureReason::ConnectionReset.should_reconnect());
9386            assert!(WebsocketConnectionFailureReason::ServerTemporaryError.should_reconnect());
9387            assert!(WebsocketConnectionFailureReason::UnexpectedClose.should_reconnect());
9388            assert!(WebsocketConnectionFailureReason::StreamEnded.should_reconnect());
9389
9390            // Non-reconnectable failures
9391            assert!(!WebsocketConnectionFailureReason::AuthenticationFailure.should_reconnect());
9392            assert!(!WebsocketConnectionFailureReason::ProtocolViolation.should_reconnect());
9393            assert!(!WebsocketConnectionFailureReason::ConfigurationError.should_reconnect());
9394            assert!(!WebsocketConnectionFailureReason::UserInitiatedClose.should_reconnect());
9395            assert!(!WebsocketConnectionFailureReason::PermanentServerError.should_reconnect());
9396            assert!(!WebsocketConnectionFailureReason::NormalClose.should_reconnect());
9397        }
9398
9399        #[test]
9400        fn debug_and_clone_work() {
9401            let reason = WebsocketConnectionFailureReason::NetworkInterruption;
9402            let cloned = reason;
9403            let debug_str = format!("{:?}", reason);
9404
9405            assert!(matches!(
9406                cloned,
9407                WebsocketConnectionFailureReason::NetworkInterruption
9408            ));
9409            assert!(debug_str.contains("NetworkInterruption"));
9410        }
9411    }
9412}