binance_sdk/common/
websocket.rs

1use async_trait::async_trait;
2use flate2::read::ZlibDecoder;
3use futures::{SinkExt, StreamExt, stream::FuturesUnordered};
4use http::header::USER_AGENT;
5use regex::Regex;
6use serde::de::DeserializeOwned;
7use serde_json::{Value, json};
8use std::{
9    collections::{BTreeMap, HashMap, VecDeque},
10    io::Read,
11    marker::PhantomData,
12    mem::take,
13    sync::{
14        Arc, LazyLock,
15        atomic::{AtomicUsize, Ordering},
16    },
17    time::Duration,
18};
19use tokio::{
20    net::TcpStream,
21    select, spawn,
22    sync::{
23        Mutex, Notify, broadcast,
24        mpsc::{Receiver, Sender, UnboundedSender, channel, unbounded_channel},
25        oneshot,
26    },
27    task::JoinHandle,
28    time::{sleep, timeout},
29};
30use tokio_tungstenite::{
31    Connector, MaybeTlsStream, WebSocketStream, connect_async_tls_with_config,
32    tungstenite::{
33        Message,
34        client::IntoClientRequest,
35        protocol::{CloseFrame, WebSocketConfig, frame::coding::CloseCode},
36    },
37};
38use tokio_util::time::DelayQueue;
39use tracing::{debug, error, info, warn};
40
41use crate::common::utils::{remove_empty_value, sort_object_params};
42
43use super::{
44    config::{AgentConnector, ConfigurationWebsocketApi, ConfigurationWebsocketStreams},
45    errors::WebsocketError,
46    models::{WebsocketApiResponse, WebsocketEvent, WebsocketMode},
47    utils::{get_timestamp, random_string, validate_time_unit},
48};
49
50static ID_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^[0-9a-f]{32}$").unwrap());
51
52pub type WebSocketClient = WebSocketStream<MaybeTlsStream<TcpStream>>;
53
/// How long a connection may stay open before the renewal loop schedules its replacement
/// (23 hours).
const MAX_CONN_DURATION: Duration = Duration::from_secs(23 * 3_600);
55
56pub struct Subscription {
57    handle: JoinHandle<()>,
58}
59
60impl Subscription {
61    /// Cancels the ongoing WebSocket event subscription and stops the event processing task.
62    ///
63    /// This method aborts the background task responsible for receiving and processing
64    /// WebSocket events, effectively unsubscribing from further event notifications.
65    ///
66    /// # Examples
67    ///
68    ///
69    /// let emitter = `WebsocketEventEmitter::new()`;
70    /// let subscription = emitter.subscribe(|event| {
71    ///     // Handle WebSocket event
72    /// });
73    /// `subscription.unsubscribe()`; // Stop receiving events
74    ///
75    pub fn unsubscribe(self) {
76        self.handle.abort();
77    }
78}
79
80#[derive(Clone)]
81pub enum WebsocketBase {
82    WebsocketApi(Arc<WebsocketApi>),
83    WebsocketStreams(Arc<WebsocketStreams>),
84}
85
86pub struct WebsocketEventEmitter {
87    tx: broadcast::Sender<WebsocketEvent>,
88}
89
90impl Default for WebsocketEventEmitter {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96impl WebsocketEventEmitter {
97    #[must_use]
98    pub fn new() -> Self {
99        let (tx, _rx) = broadcast::channel(100);
100        Self { tx }
101    }
102
103    /// Subscribes to WebSocket events and returns a `Subscription` that allows receiving events.
104    ///
105    /// # Arguments
106    ///
107    /// * `callback` - A mutable function that will be called for each received WebSocket event.
108    ///
109    /// # Returns
110    ///
111    /// A `Subscription` that can be used to manage the event subscription.
112    ///
113    /// # Examples
114    ///
115    ///
116    /// let emitter = `WebsocketEventEmitter::new()`;
117    /// let subscription = emitter.subscribe(|event| {
118    ///     // Handle WebSocket event
119    /// });
120    /// // Later, unsubscribe if needed
121    /// `subscription.unsubscribe()`;
122    ///
123    pub fn subscribe<F>(&self, mut callback: F) -> Subscription
124    where
125        F: FnMut(WebsocketEvent) + Send + 'static,
126    {
127        let mut rx = self.tx.subscribe();
128        let handle = spawn(async move {
129            while let Ok(event) = rx.recv().await {
130                callback(event);
131            }
132        });
133        Subscription { handle }
134    }
135
136    /// Sends a WebSocket event to all subscribers of this event emitter.
137    ///
138    /// # Arguments
139    ///
140    /// * `event` - The WebSocket event to be sent.
141    ///
142    /// # Remarks
143    ///
144    /// This method uses a broadcast channel to distribute the event to all registered subscribers.
145    /// If no subscribers are currently listening, the event is silently dropped.
146    fn emit(&self, event: WebsocketEvent) {
147        let _ = self.tx.send(event);
148    }
149}
150
151/// A trait defining the lifecycle and behavior of a WebSocket connection.
152///
153/// This trait provides methods for handling WebSocket connection events,
154/// including connection opening, message handling, and reconnection URL retrieval.
155///
156/// # Methods
157///
158/// * `on_open`: Called when a WebSocket connection is established
159/// * `on_message`: Called when a message is received over the WebSocket
160/// * `get_reconnect_url`: Determines the URL to use for reconnecting
161///
162/// # Thread Safety
163///
164/// Implementors must be safely shareable across threads, as indicated by the `Send + Sync + 'static` bounds.
165#[async_trait]
166pub trait WebsocketHandler: Send + Sync + 'static {
167    async fn on_open(&self, url: String, connection: Arc<WebsocketConnection>);
168    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>);
169    async fn get_reconnect_url(
170        &self,
171        default_url: String,
172        connection: Arc<WebsocketConnection>,
173    ) -> String;
174}
175
176pub struct PendingRequest {
177    pub completion: oneshot::Sender<Result<Value, WebsocketError>>,
178}
179
180pub struct WebsocketConnectionState {
181    pub reconnection_pending: bool,
182    pub renewal_pending: bool,
183    pub close_initiated: bool,
184    pub pending_requests: HashMap<String, PendingRequest>,
185    pub pending_subscriptions: VecDeque<String>,
186    pub stream_callbacks: HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>,
187    pub handler: Option<Arc<dyn WebsocketHandler>>,
188    pub ws_write_tx: Option<UnboundedSender<Message>>,
189}
190
191impl Default for WebsocketConnectionState {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197impl WebsocketConnectionState {
198    #[must_use]
199    pub fn new() -> Self {
200        Self {
201            reconnection_pending: false,
202            renewal_pending: false,
203            close_initiated: false,
204            pending_requests: HashMap::new(),
205            pending_subscriptions: VecDeque::new(),
206            stream_callbacks: HashMap::new(),
207            handler: None,
208            ws_write_tx: None,
209        }
210    }
211}
212
213pub struct WebsocketConnection {
214    pub id: String,
215    pub drain_notify: Notify,
216    pub state: Mutex<WebsocketConnectionState>,
217}
218
219impl WebsocketConnection {
220    pub fn new(id: impl Into<String>) -> Arc<Self> {
221        Arc::new(Self {
222            id: id.into(),
223            drain_notify: Notify::new(),
224            state: Mutex::new(WebsocketConnectionState::new()),
225        })
226    }
227
228    pub async fn set_handler(&self, handler: Arc<dyn WebsocketHandler>) {
229        let mut conn_state = self.state.lock().await;
230        conn_state.handler = Some(handler);
231    }
232}
233
/// A queued request to (re)establish one pooled connection.
struct ReconnectEntry {
    /// Id of the pooled connection to reconnect.
    connection_id: String,
    /// URL to connect to.
    url: String,
    /// `true` for a scheduled renewal; renewals skip the reconnect delay.
    is_renewal: bool,
}
239
240pub struct WebsocketCommon {
241    pub events: WebsocketEventEmitter,
242    mode: WebsocketMode,
243    round_robin_index: AtomicUsize,
244    connection_pool: Vec<Arc<WebsocketConnection>>,
245    reconnect_tx: Sender<ReconnectEntry>,
246    renewal_tx: Sender<(String, String)>,
247    reconnect_delay: usize,
248    agent: Option<AgentConnector>,
249    user_agent: Option<String>,
250}
251
252impl WebsocketCommon {
253    #[must_use]
254    pub fn new(
255        mut initial_pool: Vec<Arc<WebsocketConnection>>,
256        mode: WebsocketMode,
257        reconnect_delay: usize,
258        agent: Option<AgentConnector>,
259        user_agent: Option<String>,
260    ) -> Arc<Self> {
261        if initial_pool.is_empty() {
262            for _ in 0..mode.pool_size() {
263                let id = random_string();
264                initial_pool.push(WebsocketConnection::new(id));
265            }
266        }
267
268        let (reconnect_tx, reconnect_rx) = channel::<ReconnectEntry>(mode.pool_size());
269        let (renewal_tx, renewal_rx) = channel::<(String, String)>(mode.pool_size());
270
271        let common = Arc::new(Self {
272            events: WebsocketEventEmitter::new(),
273            mode,
274            round_robin_index: AtomicUsize::new(0),
275            connection_pool: initial_pool,
276            reconnect_tx,
277            renewal_tx,
278            reconnect_delay,
279            agent,
280            user_agent,
281        });
282
283        Self::spawn_reconnect_loop(Arc::clone(&common), reconnect_rx);
284        Self::spawn_renewal_loop(&Arc::clone(&common), renewal_rx);
285
286        common
287    }
288
289    /// Spawns an asynchronous loop to handle websocket reconnection attempts
290    ///
291    /// This method manages reconnection logic for websocket connections, including:
292    /// - Scheduling reconnects with a configurable delay
293    /// - Finding the appropriate connection in the connection pool
294    /// - Attempting to reinitialize the connection
295    /// - Logging reconnection failures or warnings
296    ///
297    /// # Arguments
298    /// * `common` - A shared reference to the `WebsocketCommon` instance
299    /// * `reconnect_rx` - A receiver channel for reconnection entries
300    ///
301    /// # Behavior
302    /// - Waits for reconnection entries from the channel
303    /// - Applies a configurable delay before attempting reconnection
304    /// - Attempts to reinitialize the connection with the provided URL
305    /// - Handles and logs any reconnection errors
306    fn spawn_reconnect_loop(common: Arc<Self>, mut reconnect_rx: Receiver<ReconnectEntry>) {
307        spawn(async move {
308            while let Some(entry) = reconnect_rx.recv().await {
309                info!("Scheduling reconnect for id {}", entry.connection_id);
310
311                if !entry.is_renewal {
312                    sleep(Duration::from_millis(common.reconnect_delay as u64)).await;
313                }
314
315                if let Some(conn_arc) = common
316                    .connection_pool
317                    .iter()
318                    .find(|c| c.id == entry.connection_id)
319                    .cloned()
320                {
321                    let common_clone = Arc::clone(&common);
322                    if let Err(err) = common_clone
323                        .init_connect(&entry.url, entry.is_renewal, Some(conn_arc.clone()))
324                        .await
325                    {
326                        error!(
327                            "Reconnect failed for {} → {}: {:?}",
328                            entry.connection_id, entry.url, err
329                        );
330                    }
331
332                    sleep(Duration::from_secs(1)).await;
333                } else {
334                    warn!("No connection {} found for reconnect", entry.connection_id);
335                }
336            }
337        });
338    }
339
340    /// Spawns an asynchronous loop to manage connection renewals
341    ///
342    /// This method handles the periodic renewal of websocket connections by:
343    /// - Maintaining a delay queue for connection expiration
344    /// - Receiving renewal requests for specific connections
345    /// - Triggering reconnection when a connection reaches its maximum duration
346    /// - Attempting to find and renew connections in the connection pool
347    ///
348    /// # Behavior
349    /// - Listens for renewal requests on a channel
350    /// - Tracks connection expiration using a delay queue
351    /// - Initiates reconnection process when a connection expires
352    /// - Handles and logs any renewal failures
353    fn spawn_renewal_loop(common: &Arc<Self>, renewal_rx: Receiver<(String, String)>) {
354        let common = Arc::clone(common);
355        spawn(async move {
356            let mut dq = DelayQueue::new();
357            let mut renewal_rx = renewal_rx;
358
359            loop {
360                select! {
361                    Some((conn_id, url)) = renewal_rx.recv() => {
362                        debug!("Scheduling renewal for {}", conn_id);
363                        dq.insert((conn_id, url), MAX_CONN_DURATION);
364                    }
365
366                    Some(expired) = dq.next() => {
367                        let (conn_id, default_url) = expired.into_inner();
368
369                        if let Some(conn_arc) = common
370                            .connection_pool
371                            .iter()
372                            .find(|c| c.id == conn_id)
373                            .cloned()
374                        {
375                            debug!("Renewing connection {}", conn_id);
376                            let url = common
377                                .get_reconnect_url(&default_url, Arc::clone(&conn_arc))
378                                .await;
379                            if let Err(e) = common.reconnect_tx.send(ReconnectEntry {
380                                connection_id: conn_id.clone(),
381                                url,
382                                is_renewal: true,
383                            }).await {
384                                error!(
385                                    "Failed to enqueue renewal for {}: {:?}",
386                                    conn_id, e
387                                );
388                            }
389                        } else {
390                            warn!("No connection {} found for renewal", conn_id);
391                        }
392                    }
393                }
394            }
395        });
396    }
397
398    /// Checks if a WebSocket connection is ready for use.
399    ///
400    /// # Arguments
401    ///
402    /// * `connection` - The WebSocket connection to check
403    /// * `allow_non_established` - If true, allows connections that are not fully established
404    ///
405    /// # Returns
406    ///
407    /// `true` if the connection is ready, `false` otherwise
408    ///
409    /// # Behavior
410    ///
411    /// A connection is considered ready if:
412    /// - It has a write channel (unless `allow_non_established` is true)
413    /// - No renewal is pending
414    /// - No reconnection is pending
415    /// - No close has been initiated
416    pub async fn is_connection_ready(
417        &self,
418        connection: &WebsocketConnection,
419        allow_non_established: bool,
420    ) -> bool {
421        let conn_state = connection.state.lock().await;
422        (allow_non_established || conn_state.ws_write_tx.is_some())
423            && !conn_state.renewal_pending
424            && !conn_state.reconnection_pending
425            && !conn_state.close_initiated
426    }
427
428    /// Checks if a WebSocket connection is established.
429    ///
430    /// # Arguments
431    ///
432    /// * `connection` - Optional specific WebSocket connection to check
433    ///
434    /// # Returns
435    ///
436    /// `true` if a connection is ready and established, `false` otherwise
437    ///
438    /// # Behavior
439    ///
440    /// - If a specific connection is provided, checks only that connection
441    /// - If no connection is provided, checks all connections in the pool
442    /// - A connection is considered established if it is ready and not in a non-established state
443    async fn is_connected(&self, connection: Option<&Arc<WebsocketConnection>>) -> bool {
444        if let Some(conn_arc) = connection {
445            return self.is_connection_ready(conn_arc, false).await;
446        }
447
448        for conn_arc in &self.connection_pool {
449            if self.is_connection_ready(conn_arc, false).await {
450                return true;
451            }
452        }
453
454        false
455    }
456
457    /// Retrieves a WebSocket connection from the connection pool.
458    ///
459    /// # Arguments
460    ///
461    /// * `allow_non_established` - If `true`, allows selecting a connection that is not fully established
462    ///
463    /// # Returns
464    ///
465    /// An `Arc` to a `WebsocketConnection` from the pool, selected using round-robin strategy
466    ///
467    /// # Errors
468    ///
469    /// Returns `WebsocketError::NotConnected` if no suitable connection is available
470    ///
471    /// # Behavior
472    ///
473    /// - For single connection mode, returns the first connection
474    /// - For multi-connection mode, selects a ready connection using round-robin
475    /// - Filters connections based on `allow_non_established` parameter
476    async fn get_connection(
477        &self,
478        allow_non_established: bool,
479    ) -> Result<Arc<WebsocketConnection>, WebsocketError> {
480        if let WebsocketMode::Single = self.mode {
481            return Ok(Arc::clone(&self.connection_pool[0]));
482        }
483
484        let mut ready = Vec::new();
485        for conn in &self.connection_pool {
486            if self.is_connection_ready(conn, allow_non_established).await {
487                ready.push(Arc::clone(conn));
488            }
489        }
490
491        if ready.is_empty() {
492            return Err(WebsocketError::NotConnected);
493        }
494
495        let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % ready.len();
496
497        Ok(Arc::clone(&ready[idx]))
498    }
499
500    /// Gracefully closes a WebSocket connection by waiting for pending requests to complete.
501    ///
502    /// # Arguments
503    ///
504    /// * `ws_write_tx_to_close` - Sender channel for sending close message
505    /// * `connection` - Shared reference to the WebSocket connection
506    ///
507    /// # Behavior
508    ///
509    /// - Waits up to 30 seconds for all pending requests to complete
510    /// - Logs debug and warning messages during the closing process
511    /// - Sends a normal close frame to the WebSocket
512    ///
513    /// # Returns
514    ///
515    /// `Ok(())` if connection closes successfully, otherwise a `WebsocketError`
516    async fn close_connection_gracefully(
517        &self,
518        ws_write_tx_to_close: UnboundedSender<Message>,
519        connection: Arc<WebsocketConnection>,
520    ) -> Result<(), WebsocketError> {
521        debug!("Waiting for pending requests to complete before disconnecting.");
522
523        let drain = async {
524            loop {
525                {
526                    let conn_state = connection.state.lock().await;
527                    if conn_state.pending_requests.is_empty() {
528                        debug!("All pending requests completed, proceeding to close.");
529                        break;
530                    }
531                }
532                connection.drain_notify.notified().await;
533            }
534        };
535
536        if timeout(Duration::from_secs(30), drain).await.is_err() {
537            warn!("Timeout waiting for pending requests; forcing close.");
538        }
539
540        info!("Closing WebSocket connection for {}", connection.id);
541        let _ = ws_write_tx_to_close.send(Message::Close(Some(CloseFrame {
542            code: CloseCode::Normal,
543            reason: "".into(),
544        })));
545
546        Ok(())
547    }
548
549    /// Retrieves the URL to use for reconnecting to the WebSocket.
550    ///
551    /// # Arguments
552    ///
553    /// * `default_url` - The default URL to use if no custom reconnect URL is provided
554    /// * `connection` - A shared reference to the WebSocket connection
555    ///
556    /// # Returns
557    ///
558    /// The URL to use for reconnecting, either from a custom handler or the default URL
559    ///
560    /// # Behavior
561    ///
562    /// - Checks if a connection handler is available
563    /// - If a handler exists, calls its `get_reconnect_url` method
564    /// - Otherwise, returns the default URL
565    async fn get_reconnect_url(
566        &self,
567        default_url: &str,
568        connection: Arc<WebsocketConnection>,
569    ) -> String {
570        if let Some(handler) = {
571            let conn_state = connection.state.lock().await;
572            conn_state.handler.clone()
573        } {
574            return handler
575                .get_reconnect_url(default_url.to_string(), Arc::clone(&connection))
576                .await;
577        }
578
579        default_url.to_string()
580    }
581
582    /// Handles the WebSocket connection opening event.
583    ///
584    /// This method is called when a WebSocket connection is successfully established. It performs
585    /// the following key actions:
586    /// - Invokes the connection handler's `on_open` method if a handler is present
587    /// - Logs connection information
588    /// - Handles connection renewal and close scenarios
589    /// - Emits a WebSocket open event
590    ///
591    /// # Arguments
592    ///
593    /// * `url` - The URL of the WebSocket server
594    /// * `connection` - A shared reference to the WebSocket connection
595    /// * `old_ws_writer` - Optional previous WebSocket writer for graceful connection handling
596    ///
597    /// # Behavior
598    ///
599    /// - If a connection handler exists, calls its `on_open` method
600    /// - Checks for pending renewal or close states
601    /// - Closes the previous connection if renewal is in progress
602    /// - Emits an open event if the connection is successfully established
603    async fn on_open(
604        &self,
605        url: String,
606        connection: Arc<WebsocketConnection>,
607        old_ws_writer: Option<UnboundedSender<Message>>,
608    ) {
609        if let Some(handler) = {
610            let conn_state = connection.state.lock().await;
611            conn_state.handler.clone()
612        } {
613            handler.on_open(url.clone(), Arc::clone(&connection)).await;
614        }
615
616        let conn_id = &connection.id;
617        info!("Connected to WebSocket Server with id {}: {}", conn_id, url);
618
619        {
620            let mut conn_state = connection.state.lock().await;
621
622            if conn_state.renewal_pending {
623                conn_state.renewal_pending = false;
624                drop(conn_state);
625                if let Some(tx) = old_ws_writer {
626                    info!("Connection renewal in progress; closing previous connection.");
627                    let _ = self
628                        .close_connection_gracefully(tx, Arc::clone(&connection))
629                        .await;
630                }
631                return;
632            }
633
634            if conn_state.close_initiated {
635                drop(conn_state);
636                if let Some(tx) = connection.state.lock().await.ws_write_tx.clone() {
637                    info!("Close initiated; closing connection.");
638                    let _ = self
639                        .close_connection_gracefully(tx, Arc::clone(&connection))
640                        .await;
641                }
642                return;
643            }
644
645            self.events.emit(WebsocketEvent::Open);
646        }
647    }
648
649    /// Handles an incoming WebSocket message
650    ///
651    /// # Arguments
652    ///
653    /// * `msg` - The received message as a string
654    /// * `connection` - A shared reference to the WebSocket connection
655    ///
656    /// # Behavior
657    ///
658    /// - If a connection handler exists, spawns an async task to call its `on_message` method
659    /// - Emits a `WebsocketEvent::Message` event with the received message
660    async fn on_message(&self, msg: String, connection: Arc<WebsocketConnection>) {
661        if let Some(handler) = connection.state.lock().await.handler.clone() {
662            let handler_clone = handler.clone();
663            let data = msg.clone();
664            let conn_clone = connection.clone();
665            spawn(async move {
666                handler_clone.on_message(data, conn_clone).await;
667            });
668        }
669        self.events.emit(WebsocketEvent::Message(msg));
670    }
671
672    /// Creates a WebSocket connection with optional configuration and agent
673    ///
674    /// # Arguments
675    ///
676    /// * `url` - The WebSocket server URL to connect to
677    /// * `agent` - Optional agent connector for configuring the connection
678    /// * `user_agent` - Optional custom user agent string
679    ///
680    /// # Returns
681    ///
682    /// A `Result` containing the established WebSocket stream or a `WebsocketError`
683    ///
684    /// # Errors
685    ///
686    /// Returns a `WebsocketError` if:
687    /// - The WebSocket handshake fails
688    /// - The connection times out after 10 seconds
689    ///
690    /// # Behavior
691    ///
692    /// Attempts to establish a WebSocket connection with a configurable timeout,
693    /// supporting optional TLS, custom user agent, and connection connectors
694    async fn create_websocket(
695        url: &str,
696        agent: Option<AgentConnector>,
697        user_agent: Option<String>,
698    ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, WebsocketError> {
699        let mut req = url
700            .into_client_request()
701            .map_err(|e| WebsocketError::Handshake(e.to_string()))?;
702
703        if let Some(ua) = user_agent {
704            req.headers_mut().insert(USER_AGENT, ua.parse().unwrap());
705        }
706
707        let ws_config: Option<WebSocketConfig> = None;
708        let disable_nagle = false;
709        let connector: Option<Connector> = agent.map(|dbg| dbg.0);
710
711        let timeout_duration = Duration::from_secs(10);
712        let handshake = connect_async_tls_with_config(req, ws_config, disable_nagle, connector);
713        match timeout(timeout_duration, handshake).await {
714            Ok(Ok((ws_stream, response))) => {
715                debug!("WebSocket connected: {:?}", response);
716                Ok(ws_stream)
717            }
718            Ok(Err(e)) => {
719                let msg = e.to_string();
720                error!("WebSocket handshake failed: {}", msg);
721                Err(WebsocketError::Handshake(msg))
722            }
723            Err(_) => {
724                error!(
725                    "WebSocket connection timed out after {}s",
726                    timeout_duration.as_secs()
727                );
728                Err(WebsocketError::Timeout)
729            }
730        }
731    }
732
733    /// Connects to a WebSocket URL for all connections in the connection pool concurrently
734    ///
735    /// # Arguments
736    ///
737    /// * `url` - The WebSocket server URL to connect to
738    ///
739    /// # Returns
740    ///
741    /// A `Result` indicating whether all connections were successfully established
742    ///
743    /// # Errors
744    ///
745    /// Returns a `WebsocketError` if any connection in the pool fails to establish
746    ///
747    /// # Behavior
748    ///
749    /// Attempts to initialize a WebSocket connection for each connection in the pool
750    /// concurrently, logging successes and failures for each connection attempt
751    async fn connect_pool(self: Arc<Self>, url: &str) -> Result<(), WebsocketError> {
752        let mut tasks = FuturesUnordered::new();
753
754        for conn in &self.connection_pool {
755            let common = Arc::clone(&self);
756            let url = url.to_owned();
757            let conn_clone = Arc::clone(conn);
758
759            tasks.push(async move {
760                match common.init_connect(&url, false, Some(conn_clone)).await {
761                    Ok(()) => {
762                        info!("Successfully connected to {}", url);
763                        Ok(())
764                    }
765                    Err(err) => {
766                        error!("Failed to connect to {}: {:?}", url, err);
767                        Err(err)
768                    }
769                }
770            });
771        }
772
773        while let Some(result) = tasks.next().await {
774            result?;
775        }
776
777        Ok(())
778    }
779
780    /// Initializes a WebSocket connection for a specific connection in the pool
781    ///
782    /// # Arguments
783    ///
784    /// * `url` - The WebSocket server URL to connect to
785    /// * `is_renewal` - Flag indicating whether this is a connection renewal attempt
786    /// * `connection` - Optional specific WebSocket connection to use, otherwise selects from the pool
787    ///
788    /// # Returns
789    ///
790    /// A `Result` indicating whether the connection was successfully established
791    ///
792    /// # Errors
793    ///
794    /// Returns a `WebsocketError` if the connection fails to initialize or establish
795    ///
796    /// # Behavior
797    ///
798    /// Handles connection establishment, splitting read/write streams, spawning reader/writer tasks,
799    /// and managing connection state including renewal, reconnection, and error handling
800    async fn init_connect(
801        self: Arc<Self>,
802        url: &str,
803        is_renewal: bool,
804        connection: Option<Arc<WebsocketConnection>>,
805    ) -> Result<(), WebsocketError> {
806        let conn = connection.unwrap_or(self.get_connection(true).await?);
807
808        {
809            let mut conn_state = conn.state.lock().await;
810            if conn_state.renewal_pending && is_renewal {
811                info!("Renewal in progress {}→{}", conn.id, url);
812                return Ok(());
813            }
814            if conn_state.ws_write_tx.is_some() && !is_renewal {
815                info!("Exists {}; skipping {}", conn.id, url);
816                return Ok(());
817            }
818            if is_renewal {
819                conn_state.renewal_pending = true;
820            }
821        }
822
823        let ws = Self::create_websocket(url, self.agent.clone(), self.user_agent.clone())
824            .await
825            .map_err(|e| {
826                error!("Handshake failed {}: {:?}", url, e);
827                e
828            })?;
829
830        info!("Established {} → {}", conn.id, url);
831
832        if let Err(e) = self.renewal_tx.try_send((conn.id.clone(), url.to_string())) {
833            error!("Failed to schedule renewal for {}: {:?}", conn.id, e);
834        }
835
836        let (write_half, mut read_half) = ws.split();
837        let (tx, mut rx) = unbounded_channel::<Message>();
838
839        let old_writer = {
840            let mut conn_state = conn.state.lock().await;
841            conn_state.ws_write_tx.replace(tx.clone())
842        };
843
844        let wconn = conn.clone();
845
846        spawn(async move {
847            let mut sink = write_half;
848            while let Some(msg) = rx.recv().await {
849                if sink.send(msg).await.is_err() {
850                    error!("Write error {}", wconn.id);
851                    break;
852                }
853            }
854            debug!("Writer {} exit", wconn.id);
855        });
856
857        self.on_open(url.to_string(), conn.clone(), old_writer)
858            .await;
859
860        let common = self.clone();
861        let reader_conn = conn.clone();
862        let read_url = url.to_string();
863
864        spawn(async move {
865            while let Some(item) = read_half.next().await {
866                match item {
867                    Ok(Message::Text(msg)) => {
868                        common
869                            .on_message(msg.to_string(), Arc::clone(&reader_conn))
870                            .await;
871                    }
872                    Ok(Message::Binary(bin)) => {
873                        let mut decoder = ZlibDecoder::new(&bin[..]);
874                        let mut decompressed = String::new();
875                        if let Err(err) = decoder.read_to_string(&mut decompressed) {
876                            error!("Binary message decompress failed: {:?}", err);
877                            continue;
878                        }
879                        common
880                            .on_message(decompressed, Arc::clone(&reader_conn))
881                            .await;
882                    }
883                    Ok(Message::Ping(payload)) => {
884                        info!("PING received from server on {}", reader_conn.id);
885                        common.events.emit(WebsocketEvent::Ping);
886                        if let Some(tx) = reader_conn.state.lock().await.ws_write_tx.clone() {
887                            let _ = tx.send(Message::Pong(payload));
888                            info!(
889                                "Responded PONG to server's PING message on {}",
890                                reader_conn.id
891                            );
892                        }
893                    }
894                    Ok(Message::Pong(_)) => {
895                        info!("Received PONG from server on {}", reader_conn.id);
896                        common.events.emit(WebsocketEvent::Pong);
897                    }
898                    Ok(Message::Close(frame)) => {
899                        let (code, reason) = frame
900                            .map_or((1000, String::new()), |CloseFrame { code, reason }| {
901                                (code.into(), reason.to_string())
902                            });
903                        common
904                            .events
905                            .emit(WebsocketEvent::Close(code, reason.clone()));
906
907                        let mut conn_state = reader_conn.state.lock().await;
908                        if !conn_state.close_initiated
909                            && !is_renewal
910                            && CloseCode::from(code) != CloseCode::Normal
911                        {
912                            warn!(
913                                "Connection {} closed due to {}: {}",
914                                reader_conn.id, code, reason
915                            );
916                            conn_state.reconnection_pending = true;
917                            drop(conn_state);
918                            let reconnect_url = common
919                                .get_reconnect_url(&read_url, Arc::clone(&reader_conn))
920                                .await;
921
922                            let _ = common
923                                .reconnect_tx
924                                .send(ReconnectEntry {
925                                    connection_id: reader_conn.id.clone(),
926                                    url: reconnect_url,
927                                    is_renewal: false,
928                                })
929                                .await;
930                        }
931                        break;
932                    }
933                    Err(e) => {
934                        error!("WebSocket error on {}: {:?}", reader_conn.id, e);
935                        common.events.emit(WebsocketEvent::Error(e.to_string()));
936                    }
937                    _ => {}
938                }
939            }
940            debug!("Reader actor for {} exiting", reader_conn.id);
941        });
942
943        Ok(())
944    }
945    /// Gracefully disconnects all active WebSocket connections.
946    ///
947    /// This method attempts to close all connections in the connection pool within a 30-second timeout.
948    /// It marks each connection as close-initiated and attempts to close them gracefully.
949    ///
950    /// # Returns
951    ///
952    /// - `Ok(())` if all connections are successfully closed
953    /// - `Err(WebsocketError)` if there are errors during disconnection or a timeout occurs
954    ///
955    /// # Errors
956    ///
957    /// Returns `WebsocketError::Timeout` if disconnection takes longer than 30 seconds
958    ///
959    async fn disconnect(&self) -> Result<(), WebsocketError> {
960        if !self.is_connected(None).await {
961            warn!("No active connection to close.");
962            return Ok(());
963        }
964
965        let mut shutdowns = FuturesUnordered::new();
966        for conn in &self.connection_pool {
967            {
968                let mut conn_state = conn.state.lock().await;
969                conn_state.close_initiated = true;
970                if let Some(tx) = &conn_state.ws_write_tx {
971                    shutdowns.push(self.close_connection_gracefully(tx.clone(), Arc::clone(conn)));
972                }
973            }
974        }
975
976        let close_all = async {
977            while let Some(result) = shutdowns.next().await {
978                result?;
979            }
980            Ok::<(), WebsocketError>(())
981        };
982
983        match timeout(Duration::from_secs(30), close_all).await {
984            Ok(Ok(())) => {
985                info!("Disconnected all WebSocket connections successfully.");
986                Ok(())
987            }
988            Ok(Err(err)) => {
989                error!("Error while disconnecting: {:?}", err);
990                Err(err)
991            }
992            Err(_) => {
993                error!("Timed out while disconnecting WebSocket connections.");
994                Err(WebsocketError::Timeout)
995            }
996        }
997    }
998
999    /// Sends a PING message to all ready WebSocket connections.
1000    ///
1001    /// This method iterates through the connection pool, identifies ready connections,
1002    /// and sends a PING message to each of them. It logs the number of connections
1003    /// being pinged and handles any send errors individually.
1004    ///
1005    /// # Behavior
1006    ///
1007    /// - Skips connections that are not ready
1008    /// - Logs a warning if no connections are ready
1009    /// - Sends PING messages concurrently
1010    /// - Logs debug/error messages for each PING attempt
1011    async fn ping_server(&self) {
1012        let mut ready = Vec::new();
1013        for conn in &self.connection_pool {
1014            if self.is_connection_ready(conn, false).await {
1015                let id = conn.id.clone();
1016                let ws_write_tx = {
1017                    let conn_state = conn.state.lock().await;
1018                    conn_state.ws_write_tx.clone()
1019                };
1020                ready.push((id, ws_write_tx));
1021            }
1022        }
1023
1024        if ready.is_empty() {
1025            warn!("No ready connections for PING.");
1026            return;
1027        }
1028        info!("Sending PING to {} WebSocket connections.", ready.len());
1029
1030        let mut tasks = FuturesUnordered::new();
1031        for (id, ws_write_tx_opt) in ready {
1032            if let Some(tx) = ws_write_tx_opt {
1033                tasks.push(async move {
1034                    if let Err(e) = tx.send(Message::Ping(Vec::new().into())) {
1035                        error!("Failed to send PING to {}: {:?}", id, e);
1036                    } else {
1037                        debug!("Sent PING to connection {}", id);
1038                    }
1039                });
1040            } else {
1041                error!("Connection {} was ready but has no write channel", id);
1042            }
1043        }
1044
1045        while tasks.next().await.is_some() {}
1046    }
1047
1048    /// Sends a WebSocket message and optionally waits for a reply.
1049    ///
1050    /// # Arguments
1051    ///
1052    /// * `payload` - The message payload to send
1053    /// * `id` - Optional request identifier, required when waiting for a reply
1054    /// * `wait_for_reply` - Whether to wait for a response to the message
1055    /// * `timeout` - Maximum duration to wait for a reply
1056    /// * `connection` - Optional specific WebSocket connection to use
1057    ///
1058    /// # Returns
1059    ///
1060    /// A receiver for the response if `wait_for_reply` is true, otherwise `None`
1061    ///
1062    /// # Errors
1063    ///
1064    /// Returns a `WebsocketError` if the connection is not ready or the send fails
1065    async fn send(
1066        &self,
1067        payload: String,
1068        id: Option<String>,
1069        wait_for_reply: bool,
1070        timeout: Duration,
1071        connection: Option<Arc<WebsocketConnection>>,
1072    ) -> Result<Option<oneshot::Receiver<Result<Value, WebsocketError>>>, WebsocketError> {
1073        let conn = if let Some(c) = connection {
1074            c
1075        } else {
1076            self.get_connection(false).await?
1077        };
1078
1079        if !self.is_connected(Some(&conn)).await {
1080            warn!("Send attempted on a non-connected socket");
1081            return Err(WebsocketError::NotConnected);
1082        }
1083
1084        let ws_write_tx = {
1085            let conn_state = conn.state.lock().await;
1086            conn_state
1087                .ws_write_tx
1088                .clone()
1089                .ok_or(WebsocketError::NotConnected)?
1090        };
1091
1092        debug!("Sending message to WebSocket on connection {}", conn.id);
1093
1094        ws_write_tx
1095            .send(Message::Text(payload.clone().into()))
1096            .map_err(|_| WebsocketError::NotConnected)?;
1097
1098        if !wait_for_reply {
1099            return Ok(None);
1100        }
1101
1102        let request_id = id.ok_or_else(|| {
1103            error!("id is required when waiting for a reply");
1104            WebsocketError::NotConnected
1105        })?;
1106
1107        let (tx, rx) = oneshot::channel();
1108        {
1109            let mut conn_state = conn.state.lock().await;
1110            conn_state
1111                .pending_requests
1112                .insert(request_id.clone(), PendingRequest { completion: tx });
1113        }
1114
1115        let conn_clone = Arc::clone(&conn);
1116        spawn(async move {
1117            sleep(timeout).await;
1118            let mut conn_state = conn_clone.state.lock().await;
1119            if let Some(pending_req) = conn_state.pending_requests.remove(&request_id) {
1120                let _ = pending_req.completion.send(Err(WebsocketError::Timeout));
1121            }
1122        });
1123
1124        Ok(Some(rx))
1125    }
1126}
1127
/// Per-request options for `WebsocketApi::send_message`.
///
/// Both fields are plain flags, so the type is `Copy` and can be reused across calls.
#[derive(Debug, Clone, Copy)]
pub struct WebsocketMessageSendOptions {
    /// Attach the configured API key as the `apiKey` request parameter.
    pub with_api_key: bool,
    /// Sign the request. This also adds `apiKey`, plus a `timestamp` and a
    /// `signature` parameter.
    pub is_signed: bool,
}
1132
/// Client for the request/response WebSocket API.
///
/// Wraps the shared connection machinery in `WebsocketCommon` and adds request id
/// generation, parameter signing and response routing on top of it.
pub struct WebsocketApi {
    /// Shared connection pool plus the low-level send/connect/disconnect logic.
    pub common: Arc<WebsocketCommon>,
    /// API settings read by this client: `ws_url`, `api_key`, `signature_gen`, `timeout`,
    /// `time_unit`, and the values handed to `WebsocketCommon::new`.
    configuration: ConfigurationWebsocketApi,
    /// Set while `connect` is running, so that concurrent `connect` calls return early
    /// instead of starting a second connection attempt.
    is_connecting: Arc<Mutex<bool>>,
    /// Callbacks invoked with the `event` object of pushed event messages (messages
    /// without a string `id` whose `event` has an `e` field).
    /// NOTE(review): what the map key identifies is not visible here. Presumably a
    /// subscription id; verify against the code that inserts entries.
    stream_callbacks: Mutex<HashMap<String, Vec<Arc<dyn Fn(&Value) + Send + Sync + 'static>>>>,
}
1139
1140impl WebsocketApi {
1141    #[must_use]
1142    /// Creates a new WebSocket API instance with the given configuration and connection pool.
1143    ///
1144    /// # Arguments
1145    ///
1146    /// * `configuration` - Configuration settings for the WebSocket API
1147    /// * `connection_pool` - A vector of WebSocket connections to be used
1148    ///
1149    /// # Returns
1150    ///
1151    /// An `Arc`-wrapped `WebsocketApi` instance ready for use
1152    ///
1153    /// # Panics
1154    ///
1155    /// This function will panic if the configuration is not valid.
1156    ///
1157    /// # Examples
1158    ///
1159    ///
1160    /// let api = `WebsocketApi::new(config`, `connection_pool`);
1161    ///
1162    pub fn new(
1163        configuration: ConfigurationWebsocketApi,
1164        connection_pool: Vec<Arc<WebsocketConnection>>,
1165    ) -> Arc<Self> {
1166        let agent_clone = configuration.agent.clone();
1167        let user_agent_clone = configuration.user_agent.clone();
1168        let common = WebsocketCommon::new(
1169            connection_pool,
1170            configuration.mode.clone(),
1171            usize::try_from(configuration.reconnect_delay)
1172                .expect("reconnect_delay should fit in usize"),
1173            agent_clone,
1174            Some(user_agent_clone),
1175        );
1176
1177        Arc::new(Self {
1178            common: Arc::clone(&common),
1179            configuration,
1180            is_connecting: Arc::new(Mutex::new(false)),
1181            stream_callbacks: Mutex::new(HashMap::new()),
1182        })
1183    }
1184
1185    /// Connects to a WebSocket server with a configurable timeout and connection handling.
1186    ///
1187    /// This method attempts to establish a WebSocket connection if not already connected.
1188    /// It prevents multiple simultaneous connection attempts and supports a connection pool.
1189    ///
1190    /// # Errors
1191    ///
1192    /// Returns a `WebsocketError` if:
1193    /// - Connection fails
1194    /// - Connection times out after 10 seconds
1195    ///
1196    /// # Behavior
1197    ///
1198    /// - Checks if already connected and returns early if so
1199    /// - Prevents multiple concurrent connection attempts
1200    /// - Sets a WebSocket handler for the connection pool
1201    /// - Attempts to connect with a 10-second timeout
1202    ///
1203    /// # Returns
1204    ///
1205    /// `Ok(())` if connection is successful, otherwise a `WebsocketError`
1206    pub async fn connect(self: Arc<Self>) -> Result<(), WebsocketError> {
1207        if self.common.is_connected(None).await {
1208            info!("WebSocket connection already established");
1209            return Ok(());
1210        }
1211
1212        {
1213            let mut flag = self.is_connecting.lock().await;
1214            if *flag {
1215                info!("Already connecting...");
1216                return Ok(());
1217            }
1218            *flag = true;
1219        }
1220
1221        let url = self.prepare_url(self.configuration.ws_url.as_deref().unwrap_or_default());
1222
1223        let handler: Arc<dyn WebsocketHandler> = self.clone();
1224        for slot in &self.common.connection_pool {
1225            slot.set_handler(handler.clone()).await;
1226        }
1227
1228        let result = select! {
1229            () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1230            r = self.common.clone().connect_pool(&url) => r,
1231        };
1232
1233        {
1234            let mut flag = self.is_connecting.lock().await;
1235            *flag = false;
1236        }
1237
1238        result
1239    }
1240
1241    /// Disconnects the WebSocket connection.
1242    ///
1243    /// # Returns
1244    ///
1245    /// `Ok(())` if disconnection is successful, otherwise a `WebsocketError`
1246    ///
1247    /// # Errors
1248    ///
1249    /// Returns a `WebsocketError` if:
1250    /// - Disconnection fails
1251    /// - Connection is not established
1252    ///
1253    pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1254        self.common.disconnect().await
1255    }
1256
1257    /// Checks if the WebSocket connection is currently established.
1258    ///
1259    /// # Returns
1260    ///
1261    /// `true` if the connection is active, `false` otherwise.
1262    pub async fn is_connected(&self) -> bool {
1263        self.common.is_connected(None).await
1264    }
1265
1266    /// Sends a ping to the WebSocket server to maintain the connection.
1267    ///
1268    /// This method calls the underlying connection's ping mechanism to check
1269    /// and keep the WebSocket connection alive.
1270    pub async fn ping_server(&self) {
1271        self.common.ping_server().await;
1272    }
1273
1274    /// Sends a WebSocket message with the specified method and payload.
1275    ///
1276    /// This method prepares and sends a WebSocket request with optional API key and signature.
1277    /// It handles connection status, generates a unique request ID, and processes the response.
1278    ///
1279    /// # Arguments
1280    ///
1281    /// * `method` - The WebSocket API method to be called
1282    /// * `payload` - A map of parameters to be sent with the request
1283    /// * `options` - Configuration options for message sending (API key, signing)
1284    ///
1285    /// # Returns
1286    ///
1287    /// A deserialized response of type `R` or a `WebsocketError` if the request fails
1288    ///
1289    /// # Panics
1290    ///
1291    /// Panics if:
1292    ///
1293    /// - The WebSocket is not connected
1294    /// - The request cannot be processed
1295    /// - No response is received within the timeout
1296    ///
1297    /// # Errors
1298    ///
1299    /// Returns `WebsocketError` if:
1300    /// - The WebSocket is not connected
1301    /// - The request cannot be processed
1302    /// - No response is received within the timeout
1303    pub async fn send_message<R>(
1304        &self,
1305        method: &str,
1306        mut payload: BTreeMap<String, Value>,
1307        options: WebsocketMessageSendOptions,
1308    ) -> Result<WebsocketApiResponse<R>, WebsocketError>
1309    where
1310        R: DeserializeOwned + Send + Sync + 'static,
1311    {
1312        if !self.common.is_connected(None).await {
1313            return Err(WebsocketError::NotConnected);
1314        }
1315
1316        let id = payload
1317            .get("id")
1318            .and_then(Value::as_str)
1319            .filter(|s| ID_REGEX.is_match(s))
1320            .map_or_else(random_string, String::from);
1321
1322        payload.remove("id");
1323
1324        let mut params = remove_empty_value(payload.into_iter());
1325        if options.with_api_key || options.is_signed {
1326            params.insert(
1327                "apiKey".into(),
1328                Value::String(
1329                    self.configuration
1330                        .api_key
1331                        .clone()
1332                        .expect("API key must be set"),
1333                ),
1334            );
1335        }
1336        if options.is_signed {
1337            let ts = get_timestamp();
1338            let ts_i64 = i64::try_from(ts).map_err(|e| WebsocketError::Protocol(e.to_string()))?;
1339            params.insert(
1340                "timestamp".into(),
1341                Value::Number(serde_json::Number::from(ts_i64)),
1342            );
1343            let mut sorted_params = sort_object_params(&params);
1344            let sig = self
1345                .configuration
1346                .signature_gen
1347                .get_signature(&sorted_params)
1348                .map_err(|e| WebsocketError::Protocol(e.to_string()))?;
1349            sorted_params.insert("signature".into(), Value::String(sig));
1350            params = sorted_params.into_iter().collect();
1351        }
1352
1353        let request = json!({
1354            "id": id,
1355            "method": method,
1356            "params": params,
1357        });
1358        debug!("Sending message to WebSocket API: {:?}", request);
1359
1360        let timeout = Duration::from_millis(self.configuration.timeout);
1361        let maybe_rx = self
1362            .common
1363            .send(
1364                serde_json::to_string(&request).unwrap(),
1365                Some(id.clone()),
1366                true,
1367                timeout,
1368                None,
1369            )
1370            .await?;
1371
1372        let msg: Value = if let Some(rx) = maybe_rx {
1373            rx.await.unwrap_or(Err(WebsocketError::Timeout))?
1374        } else {
1375            return Err(WebsocketError::NoResponse);
1376        };
1377
1378        let raw = msg
1379            .get("result")
1380            .or_else(|| msg.get("response"))
1381            .cloned()
1382            .unwrap_or(Value::Null);
1383
1384        let rate_limits = msg
1385            .get("rateLimits")
1386            .and_then(Value::as_array)
1387            .map(|arr| {
1388                arr.iter()
1389                    .filter_map(|v| serde_json::from_value(v.clone()).ok())
1390                    .collect()
1391            })
1392            .unwrap_or_default();
1393
1394        Ok(WebsocketApiResponse {
1395            raw,
1396            rate_limits,
1397            _marker: PhantomData,
1398        })
1399    }
1400
1401    /// Prepares a WebSocket URL by appending a validated time unit parameter.
1402    ///
1403    /// This method checks if a time unit is configured and validates it. If valid,
1404    /// the time unit is appended to the URL as a query parameter. If no time unit
1405    /// is specified or the validation fails, the original URL is returned.
1406    ///
1407    /// # Arguments
1408    ///
1409    /// * `ws_url` - The base WebSocket URL to be modified
1410    ///
1411    /// # Returns
1412    ///
1413    /// A modified URL with the time unit parameter, or the original URL if no
1414    /// modification is possible
1415    fn prepare_url(&self, ws_url: &str) -> String {
1416        let mut url = ws_url.to_string();
1417
1418        let time_unit = match &self.configuration.time_unit {
1419            Some(u) => u.to_string(),
1420            None => return url,
1421        };
1422
1423        match validate_time_unit(&time_unit) {
1424            Ok(Some(validated)) => {
1425                let sep = if url.contains('?') { '&' } else { '?' };
1426                url.push(sep);
1427                url.push_str("timeUnit=");
1428                url.push_str(validated);
1429            }
1430            Ok(None) => {}
1431            Err(e) => {
1432                error!("Invalid time unit provided: {:?}", e);
1433            }
1434        }
1435
1436        url
1437    }
1438}
1439
1440#[async_trait]
1441impl WebsocketHandler for WebsocketApi {
1442    /// Callback method invoked when a WebSocket connection is successfully opened.
1443    ///
1444    /// This method is called after a WebSocket connection is established. Currently,
1445    /// it does not perform any actions and serves as a placeholder for potential
1446    /// connection initialization or logging.
1447    ///
1448    /// # Arguments
1449    ///
1450    /// * `_url` - The URL of the WebSocket connection that was opened
1451    /// * `_connection` - An Arc-wrapped WebSocket connection context
1452    ///
1453    /// # Remarks
1454    ///
1455    /// This method can be overridden by implementations to add custom logic
1456    /// when a WebSocket connection is first opened.
1457    async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
1458
1459    /// Handles incoming WebSocket messages by parsing the JSON payload and processing pending requests.
1460    ///
1461    /// This method is responsible for:
1462    /// - Parsing the received WebSocket message as JSON
1463    /// - Matching the message to a pending request by its ID
1464    /// - Sending the response back to the original request's completion channel
1465    /// - Handling both successful and error responses
1466    ///
1467    /// # Arguments
1468    ///
1469    /// * `data` - The raw WebSocket message as a string
1470    /// * `connection` - The WebSocket connection context associated with the message
1471    ///
1472    /// # Behavior
1473    ///
1474    /// - If message parsing fails, logs an error and returns
1475    /// - For known request IDs, sends the response to the corresponding completion channel
1476    /// - Warns about responses for unknown or timed-out requests
1477    /// - Differentiates between successful (status < 400) and error responses
1478    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
1479        let msg: Value = match serde_json::from_str(&data) {
1480            Ok(v) => v,
1481            Err(err) => {
1482                error!("Failed to parse WebSocket message {} – {}", data, err);
1483                return;
1484            }
1485        };
1486
1487        if let Some(id) = msg.get("id").and_then(Value::as_str) {
1488            let maybe_sender = {
1489                let mut conn_state = connection.state.lock().await;
1490                conn_state.pending_requests.remove(id)
1491            };
1492
1493            if let Some(PendingRequest { completion }) = maybe_sender {
1494                connection.drain_notify.notify_one();
1495                let status = msg.get("status").and_then(Value::as_u64).unwrap_or(200);
1496                if status >= 400 {
1497                    let error_map = msg
1498                        .get("error")
1499                        .and_then(Value::as_object)
1500                        .unwrap_or(&serde_json::Map::new())
1501                        .clone();
1502
1503                    let code = error_map
1504                        .get("code")
1505                        .and_then(Value::as_i64)
1506                        .unwrap_or(status.try_into().unwrap());
1507
1508                    let message = error_map
1509                        .get("msg")
1510                        .and_then(Value::as_str)
1511                        .unwrap_or("Unknown error")
1512                        .to_string();
1513
1514                    let _ = completion.send(Err(WebsocketError::ResponseError { code, message }));
1515                } else {
1516                    let _ = completion.send(Ok(msg.clone()));
1517                }
1518            }
1519
1520            return;
1521        }
1522
1523        if let Some(event) = msg.get("event") {
1524            if event.get("e").is_some() {
1525                for callbacks in self.stream_callbacks.lock().await.values() {
1526                    for callback in callbacks {
1527                        callback(event);
1528                    }
1529                }
1530
1531                return;
1532            }
1533        }
1534
1535        warn!(
1536            "Received response for unknown or timed-out request: {}",
1537            data
1538        );
1539    }
1540
1541    /// Generates the URL to use for reconnecting to a WebSocket connection.
1542    ///
1543    /// # Arguments
1544    ///
1545    /// * `default_url` - The original URL to potentially modify for reconnection
1546    /// * `_connection` - The WebSocket connection context (currently unused)
1547    ///
1548    /// # Returns
1549    ///
1550    /// A `String` representing the URL to use for reconnecting
1551    async fn get_reconnect_url(
1552        &self,
1553        default_url: String,
1554        _connection: Arc<WebsocketConnection>,
1555    ) -> String {
1556        default_url
1557    }
1558}
1559
/// Client for subscribing to WebSocket streams over a pool of connections.
pub struct WebsocketStreams {
    /// Shared connection pool plus the low-level connect/disconnect/ping logic.
    pub common: Arc<WebsocketCommon>,
    /// Set while `connect` is running, so that concurrent `connect` calls return early.
    is_connecting: Mutex<bool>,
    /// Stream name -> connection that carries it. `subscribe` skips streams already
    /// present here, and `disconnect` clears it.
    /// NOTE(review): the code that inserts entries is not visible here; presumably it
    /// runs when a subscription is assigned.
    connection_streams: Mutex<HashMap<String, Arc<WebsocketConnection>>>,
    /// Settings this client was built from (mode, reconnect delay, agent, user agent, ...).
    configuration: ConfigurationWebsocketStreams,
}
1566
1567impl WebsocketStreams {
1568    /// Creates a new `WebsocketStreams` instance with the given configuration and connection pool.
1569    ///
1570    /// # Arguments
1571    ///
1572    /// * `configuration` - Configuration settings for the WebSocket streams
1573    /// * `connection_pool` - A vector of WebSocket connections to use
1574    ///
1575    /// # Returns
1576    ///
1577    /// An `Arc`-wrapped `WebsocketStreams` instance
1578    ///
1579    /// # Panics
1580    ///
1581    /// Panics if the `reconnect_delay` cannot be converted to `usize`
1582    #[must_use]
1583    pub fn new(
1584        configuration: ConfigurationWebsocketStreams,
1585        connection_pool: Vec<Arc<WebsocketConnection>>,
1586    ) -> Arc<Self> {
1587        let agent_clone = configuration.agent.clone();
1588        let user_agent_clone = configuration.user_agent.clone();
1589        let common = WebsocketCommon::new(
1590            connection_pool,
1591            configuration.mode.clone(),
1592            usize::try_from(configuration.reconnect_delay)
1593                .expect("reconnect_delay should fit in usize"),
1594            agent_clone,
1595            Some(user_agent_clone),
1596        );
1597        Arc::new(Self {
1598            common,
1599            is_connecting: Mutex::new(false),
1600            connection_streams: Mutex::new(HashMap::new()),
1601            configuration,
1602        })
1603    }
1604
1605    /// Establishes a WebSocket connection for the given streams.
1606    ///
1607    /// This method attempts to connect to a WebSocket server using the connection pool.
1608    /// If a connection is already established or in progress, it returns immediately.
1609    ///
1610    /// # Arguments
1611    ///
1612    /// * `streams` - A vector of stream identifiers to connect to
1613    ///
1614    /// # Returns
1615    ///
1616    /// A `Result` indicating whether the connection was successful or an error occurred
1617    ///
1618    /// # Errors
1619    ///
1620    /// Returns a `WebsocketError` if the connection fails or times out after 10 seconds
1621    pub async fn connect(self: Arc<Self>, streams: Vec<String>) -> Result<(), WebsocketError> {
1622        if self.common.is_connected(None).await {
1623            info!("WebSocket connection already established");
1624            return Ok(());
1625        }
1626
1627        {
1628            let mut flag = self.is_connecting.lock().await;
1629            if *flag {
1630                info!("Already connecting...");
1631                return Ok(());
1632            }
1633            *flag = true;
1634        }
1635
1636        let url = self.prepare_url(&streams);
1637
1638        let handler: Arc<dyn WebsocketHandler> = self.clone();
1639        for conn in &self.common.connection_pool {
1640            conn.set_handler(handler.clone()).await;
1641        }
1642
1643        let connect_res = select! {
1644            () = sleep(Duration::from_secs(10)) => Err(WebsocketError::Timeout),
1645            r = self.common.clone().connect_pool(&url) => r,
1646        };
1647
1648        {
1649            let mut flag = self.is_connecting.lock().await;
1650            *flag = false;
1651        }
1652
1653        connect_res
1654    }
1655
1656    /// Disconnects all WebSocket connections and clears associated state.
1657    ///
1658    /// # Returns
1659    ///
1660    /// A `Result` indicating whether the disconnection was successful or an error occurred
1661    ///
1662    /// # Errors
1663    ///
1664    /// Returns a `WebsocketError` if there are issues during the disconnection process
1665    ///
1666    /// # Side Effects
1667    ///
1668    /// - Clears stream callbacks for all connections
1669    /// - Clears pending subscriptions for all connections
1670    /// - Removes all connection stream mappings
1671    pub async fn disconnect(&self) -> Result<(), WebsocketError> {
1672        for connection in &self.common.connection_pool {
1673            let mut conn_state = connection.state.lock().await;
1674            conn_state.stream_callbacks.clear();
1675            conn_state.pending_subscriptions.clear();
1676        }
1677        self.connection_streams.lock().await.clear();
1678        self.common.disconnect().await
1679    }
1680
1681    /// Checks if the WebSocket connection is currently active.
1682    ///
1683    /// # Returns
1684    ///
1685    /// `true` if the WebSocket connection is established, `false` otherwise.
1686    pub async fn is_connected(&self) -> bool {
1687        self.common.is_connected(None).await
1688    }
1689
1690    /// Sends a ping to the WebSocket server to maintain the connection.
1691    ///
1692    /// This method delegates the ping operation to the underlying common WebSocket connection.
1693    /// It is typically used to keep the connection alive and check its status.
1694    ///
1695    /// # Side Effects
1696    ///
1697    /// Sends a ping request to the WebSocket server through the common connection.
1698    pub async fn ping_server(&self) {
1699        self.common.ping_server().await;
1700    }
1701
1702    /// Subscribes to multiple WebSocket streams, handling connection and queuing logic.
1703    ///
1704    /// # Arguments
1705    ///
1706    /// * `streams` - A vector of stream names to subscribe to
1707    /// * `id` - An optional request identifier for the subscription
1708    ///
1709    /// # Behavior
1710    ///
1711    /// - Filters out streams already subscribed
1712    /// - Assigns streams to appropriate connections
1713    /// - Handles subscription for active connections
1714    /// - Queues subscriptions for inactive connections
1715    ///
1716    /// # Side Effects
1717    ///
1718    /// - Sends subscription payloads for active connections
1719    /// - Adds pending subscriptions for inactive connections
1720    pub async fn subscribe(self: Arc<Self>, streams: Vec<String>, id: Option<String>) {
1721        let streams: Vec<String> = {
1722            let map = self.connection_streams.lock().await;
1723            streams
1724                .into_iter()
1725                .filter(|s| !map.contains_key(s))
1726                .collect()
1727        };
1728        let connection_streams = self.handle_stream_assignment(streams).await;
1729
1730        for (conn, streams) in connection_streams {
1731            if !self.common.is_connected(Some(&conn)).await {
1732                info!(
1733                    "Connection is not ready. Queuing subscription for streams: {:?}",
1734                    streams
1735                );
1736                let mut conn_state = conn.state.lock().await;
1737                conn_state.pending_subscriptions.extend(streams.clone());
1738                continue;
1739            }
1740            self.send_subscription_payload(&conn, &streams, id.clone());
1741        }
1742    }
1743
1744    /// Unsubscribes from specified WebSocket streams.
1745    ///
1746    /// # Arguments
1747    ///
1748    /// * `streams` - A vector of stream names to unsubscribe from
1749    /// * `id` - An optional request identifier for the unsubscription
1750    ///
1751    /// # Behavior
1752    ///
1753    /// - Validates the request identifier or generates a random one
1754    /// - Checks for active connections and subscribed streams
1755    /// - Sends unsubscribe payload for streams with active callbacks
1756    /// - Removes stream from connection streams and callbacks
1757    ///
1758    /// # Side Effects
1759    ///
1760    /// - Sends unsubscribe request to WebSocket server
1761    /// - Removes stream tracking from internal state
1762    ///
1763    /// # Async
1764    ///
1765    /// This method is asynchronous and requires `.await` when called
1766    ///
1767    /// # Panics
1768    ///
1769    /// This method may panic if the request identifier is not valid.
1770    ///
1771    pub async fn unsubscribe(&self, streams: Vec<String>, id: Option<String>) {
1772        let request_id = id
1773            .filter(|s| ID_REGEX.is_match(s))
1774            .unwrap_or_else(random_string);
1775
1776        for stream in streams {
1777            let maybe_conn = { self.connection_streams.lock().await.get(&stream).cloned() };
1778
1779            let conn = if let Some(c) = maybe_conn {
1780                if !self.common.is_connected(Some(&c)).await {
1781                    warn!(
1782                        "Stream {} not associated with an active connection.",
1783                        stream
1784                    );
1785                    continue;
1786                }
1787                c
1788            } else {
1789                warn!("Stream {} was not subscribed.", stream);
1790                continue;
1791            };
1792
1793            let callbacks = {
1794                let conn_state = conn.state.lock().await;
1795                conn_state
1796                    .stream_callbacks
1797                    .get(&stream)
1798                    .is_none_or(std::vec::Vec::is_empty)
1799            };
1800
1801            if !callbacks {
1802                continue;
1803            }
1804
1805            let payload = json!({
1806                "method": "UNSUBSCRIBE",
1807                "params": [stream.clone()],
1808                "id": request_id,
1809            });
1810
1811            info!("UNSUBSCRIBE → {:?}", payload);
1812
1813            let common = Arc::clone(&self.common);
1814            let conn_clone = Arc::clone(&conn);
1815            let msg = serde_json::to_string(&payload).unwrap();
1816            spawn(async move {
1817                let _ = common
1818                    .send(msg, None, false, Duration::ZERO, Some(conn_clone))
1819                    .await;
1820            });
1821
1822            {
1823                let mut connection_streams = self.connection_streams.lock().await;
1824                connection_streams.remove(&stream);
1825            }
1826            {
1827                let mut conn_state = conn.state.lock().await;
1828                conn_state.stream_callbacks.remove(&stream);
1829            }
1830        }
1831    }
1832
1833    /// Checks if a specific stream is currently subscribed.
1834    ///
1835    /// # Arguments
1836    ///
1837    /// * `stream` - The stream identifier to check for subscription status
1838    ///
1839    /// # Returns
1840    ///
1841    /// `true` if the stream is subscribed, `false` otherwise
1842    ///
1843    /// # Async
1844    ///
1845    /// This method is asynchronous and requires `.await` when called
1846    pub async fn is_subscribed(&self, stream: &str) -> bool {
1847        self.connection_streams.lock().await.contains_key(stream)
1848    }
1849
1850    /// Prepares a WebSocket URL for streaming with optional stream names and time unit configuration.
1851    ///
1852    /// # Arguments
1853    ///
1854    /// * `streams` - A slice of stream names to be included in the URL
1855    ///
1856    /// # Returns
1857    ///
1858    /// A fully constructed WebSocket URL with optional stream and time unit parameters
1859    ///
1860    /// # Notes
1861    ///
1862    /// - If no time unit is specified, the base URL is returned
1863    /// - Validates and appends the time unit parameter if provided and valid
1864    /// - Handles URL parameter separator based on existing query parameters
1865    fn prepare_url(&self, streams: &[String]) -> String {
1866        let mut url = format!(
1867            "{}/stream?streams={}",
1868            self.configuration.ws_url.as_deref().unwrap_or(""),
1869            streams.join("/")
1870        );
1871
1872        let time_unit = match &self.configuration.time_unit {
1873            Some(u) => u.to_string(),
1874            None => return url,
1875        };
1876
1877        match validate_time_unit(&time_unit) {
1878            Ok(Some(validated)) => {
1879                let sep = if url.contains('?') { '&' } else { '?' };
1880                url.push(sep);
1881                url.push_str("timeUnit=");
1882                url.push_str(validated);
1883            }
1884            Ok(None) => {}
1885            Err(e) => {
1886                error!("Invalid time unit provided: {:?}", e);
1887            }
1888        }
1889
1890        url
1891    }
1892
1893    /// Handles stream assignment by finding or creating WebSocket connections for a list of streams.
1894    ///
1895    /// This method attempts to assign streams to existing WebSocket connections or creates new
1896    /// connections if needed. It groups streams by their assigned connections and handles scenarios
1897    /// such as closed or pending reconnection connections.
1898    ///
1899    /// # Arguments
1900    ///
1901    /// * `streams` - A vector of stream names to be assigned
1902    ///
1903    /// # Returns
1904    ///
1905    /// A vector of tuples containing WebSocket connections and their associated streams
1906    ///
1907    /// # Errors
1908    ///
1909    /// Returns an empty result if no connections can be established for the streams
1910    async fn handle_stream_assignment(
1911        &self,
1912        streams: Vec<String>,
1913    ) -> Vec<(Arc<WebsocketConnection>, Vec<String>)> {
1914        let mut connection_streams: Vec<(String, Arc<WebsocketConnection>)> = Vec::new();
1915
1916        for stream in streams {
1917            let mut conn_opt = {
1918                let map = self.connection_streams.lock().await;
1919                map.get(&stream).cloned()
1920            };
1921
1922            let need_new = if let Some(conn) = &conn_opt {
1923                let state = conn.state.lock().await;
1924                state.close_initiated || state.reconnection_pending
1925            } else {
1926                true
1927            };
1928
1929            if need_new {
1930                match self.common.get_connection(true).await {
1931                    Ok(new_conn) => {
1932                        let mut map = self.connection_streams.lock().await;
1933                        map.insert(stream.clone(), new_conn.clone());
1934                        conn_opt = Some(new_conn);
1935                    }
1936                    Err(err) => {
1937                        warn!(
1938                            "No available WebSocket connection to subscribe stream `{}`: {:?}",
1939                            stream, err
1940                        );
1941                        continue;
1942                    }
1943                }
1944            }
1945
1946            if let Some(conn) = conn_opt {
1947                {
1948                    let mut conn_state = conn.state.lock().await;
1949                    conn_state
1950                        .stream_callbacks
1951                        .entry(stream.clone())
1952                        .or_default();
1953                }
1954                connection_streams.push((stream.clone(), conn));
1955            }
1956        }
1957
1958        let mut groups: Vec<(Arc<WebsocketConnection>, Vec<String>)> = Vec::new();
1959        for (stream, conn) in connection_streams {
1960            if let Some((_, vec)) = groups.iter_mut().find(|(c, _)| Arc::ptr_eq(c, &conn)) {
1961                vec.push(stream);
1962            } else {
1963                groups.push((conn, vec![stream]));
1964            }
1965        }
1966
1967        groups
1968    }
1969
1970    /// Sends a WebSocket subscription payload for the specified streams.
1971    ///
1972    /// # Arguments
1973    ///
1974    /// * `connection` - The WebSocket connection to send the subscription on
1975    /// * `streams` - A vector of stream names to subscribe to
1976    /// * `id` - An optional request ID for the subscription (will be randomly generated if not provided)
1977    ///
1978    /// # Remarks
1979    ///
1980    /// This method constructs a SUBSCRIBE payload, logs it, and sends it asynchronously using the WebSocket connection.
1981    /// If serialization fails, an error is logged and the method returns without sending.
1982    fn send_subscription_payload(
1983        &self,
1984        connection: &Arc<WebsocketConnection>,
1985        streams: &Vec<String>,
1986        id: Option<String>,
1987    ) {
1988        let request_id = id
1989            .filter(|s| ID_REGEX.is_match(s))
1990            .unwrap_or_else(random_string);
1991
1992        let payload = json!({
1993            "method": "SUBSCRIBE",
1994            "params": streams,
1995            "id": request_id,
1996        });
1997
1998        info!("SUBSCRIBE → {:?}", payload);
1999
2000        let common = Arc::clone(&self.common);
2001        let msg = match serde_json::to_string(&payload) {
2002            Ok(s) => s,
2003            Err(e) => {
2004                error!("Failed to serialize SUBSCRIBE payload: {}", e);
2005                return;
2006            }
2007        };
2008        let conn_clone = Arc::clone(connection);
2009
2010        spawn(async move {
2011            let _ = common
2012                .send(msg, None, false, Duration::ZERO, Some(conn_clone))
2013                .await;
2014        });
2015    }
2016}
2017
2018#[async_trait]
2019impl WebsocketHandler for WebsocketStreams {
2020    /// Handles the WebSocket connection opening by processing any pending subscriptions.
2021    ///
2022    /// This method is called when a WebSocket connection is established. It retrieves
2023    /// any pending stream subscriptions from the connection state and sends them
2024    /// immediately using the `send_subscription_payload` method.
2025    ///
2026    /// # Arguments
2027    ///
2028    /// * `_url` - The URL of the WebSocket connection (unused)
2029    /// * `connection` - The WebSocket connection that has just been opened
2030    ///
2031    /// # Remarks
2032    ///
2033    /// If there are any pending subscriptions, they are sent as a batch subscription
2034    /// payload. The method uses a lock to safely access and clear the pending subscriptions
2035    /// from the connection state.
2036    async fn on_open(&self, _url: String, connection: Arc<WebsocketConnection>) {
2037        let pending_subs: Vec<String> = {
2038            let mut conn_state = connection.state.lock().await;
2039            take(&mut conn_state.pending_subscriptions)
2040                .into_iter()
2041                .collect()
2042        };
2043
2044        if !pending_subs.is_empty() {
2045            info!("Processing queued subscriptions for connection");
2046            self.send_subscription_payload(&connection, &pending_subs, None);
2047        }
2048    }
2049
2050    /// Handles incoming WebSocket stream messages by parsing the JSON payload and invoking registered stream callbacks.
2051    ///
2052    /// This method processes WebSocket messages with a specific structure, extracting the stream name and data.
2053    /// It retrieves and executes any registered callbacks associated with the stream name.
2054    ///
2055    /// # Arguments
2056    ///
2057    /// * `data` - The raw WebSocket message as a JSON-formatted string
2058    /// * `connection` - The WebSocket connection through which the message was received
2059    ///
2060    /// # Behavior
2061    ///
2062    /// - Parses the JSON message
2063    /// - Extracts the stream name and data payload
2064    /// - Looks up and invokes any registered callbacks for the stream
2065    /// - Silently returns if message parsing or stream extraction fails
2066    async fn on_message(&self, data: String, connection: Arc<WebsocketConnection>) {
2067        let msg: Value = match serde_json::from_str(&data) {
2068            Ok(v) => v,
2069            Err(err) => {
2070                error!(
2071                    "Failed to parse WebSocket stream message {} – {}",
2072                    data, err
2073                );
2074                return;
2075            }
2076        };
2077
2078        let (stream_name, payload) = match (
2079            msg.get("stream").and_then(Value::as_str),
2080            msg.get("data").cloned(),
2081        ) {
2082            (Some(name), Some(data)) => (name.to_string(), data),
2083            _ => return,
2084        };
2085
2086        let callbacks = {
2087            let conn_state = connection.state.lock().await;
2088            conn_state
2089                .stream_callbacks
2090                .get(&stream_name)
2091                .cloned()
2092                .unwrap_or_else(Vec::new)
2093        };
2094
2095        for callback in callbacks {
2096            callback(&payload);
2097        }
2098    }
2099
2100    /// Retrieves the reconnection URL for a specific WebSocket connection by identifying all streams associated with that connection.
2101    ///
2102    /// # Arguments
2103    ///
2104    /// * `_default_url` - A default URL that can be used if no specific reconnection URL is determined
2105    /// * `connection` - The WebSocket connection for which to generate a reconnection URL
2106    ///
2107    /// # Returns
2108    ///
2109    /// A URL string that can be used to reconnect to the WebSocket, based on the streams associated with the given connection
2110    async fn get_reconnect_url(
2111        &self,
2112        _default_url: String,
2113        connection: Arc<WebsocketConnection>,
2114    ) -> String {
2115        let connection_streams = self.connection_streams.lock().await;
2116        let reconnect_streams = connection_streams
2117            .iter()
2118            .filter_map(|(stream, conn_arc)| {
2119                if Arc::ptr_eq(conn_arc, &connection) {
2120                    Some(stream.clone())
2121                } else {
2122                    None
2123                }
2124            })
2125            .collect::<Vec<_>>();
2126        self.prepare_url(&reconnect_streams)
2127    }
2128}
2129
2130pub struct WebsocketStream<T> {
2131    websocket_base: WebsocketBase,
2132    stream_or_id: String,
2133    callback: Mutex<Option<Arc<dyn Fn(&Value) + Send + Sync>>>,
2134    pub id: Option<String>,
2135    _phantom: PhantomData<T>,
2136}
2137
2138impl<T> WebsocketStream<T>
2139where
2140    T: DeserializeOwned + Send + 'static,
2141{
2142    /// Registers a callback function for a specific event on the WebSocket stream.
2143    ///
2144    /// This method currently only supports the "message" event. When a message is received,
2145    /// the provided callback function will be invoked with the deserialized payload.
2146    ///
2147    /// # Arguments
2148    ///
2149    /// * `event` - The event type to listen for (currently only "message" is supported)
2150    /// * `callback_fn` - A function that will be called with the deserialized message payload
2151    ///
2152    /// # Errors
2153    ///
2154    /// Logs an error if the payload cannot be deserialized into the expected type
2155    ///
2156    /// # Examples
2157    ///
2158    ///
2159    /// stream.on("message", |data: `MyType`| {
2160    ///     // Handle the deserialized message
2161    /// });
2162    async fn on<F>(&self, event: &str, callback_fn: F)
2163    where
2164        F: Fn(T) + Send + Sync + 'static,
2165    {
2166        if event != "message" {
2167            return;
2168        }
2169
2170        let cb_wrapper: Arc<dyn Fn(&Value) + Send + Sync> =
2171            Arc::new(
2172                move |v: &Value| match serde_json::from_value::<T>(v.clone()) {
2173                    Ok(data) => callback_fn(data),
2174                    Err(e) => error!("Failed to deserialize stream payload: {:?}", e),
2175                },
2176            );
2177
2178        {
2179            let mut guard = self.callback.lock().await;
2180            *guard = Some(cb_wrapper.clone());
2181        }
2182
2183        match &self.websocket_base {
2184            WebsocketBase::WebsocketStreams(ws_streams) => {
2185                let conn = {
2186                    let map = ws_streams.connection_streams.lock().await;
2187                    map.get(&self.stream_or_id)
2188                        .cloned()
2189                        .expect("stream must be subscribed")
2190                };
2191
2192                {
2193                    let mut conn_state = conn.state.lock().await;
2194                    let entry = conn_state
2195                        .stream_callbacks
2196                        .entry(self.stream_or_id.clone())
2197                        .or_default();
2198
2199                    if !entry
2200                        .iter()
2201                        .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2202                    {
2203                        entry.push(cb_wrapper);
2204                    }
2205                }
2206            }
2207            WebsocketBase::WebsocketApi(ws_api) => {
2208                let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2209                let entry = stream_callbacks
2210                    .entry(self.stream_or_id.clone())
2211                    .or_default();
2212
2213                if !entry
2214                    .iter()
2215                    .any(|existing| Arc::ptr_eq(existing, &cb_wrapper))
2216                {
2217                    entry.push(cb_wrapper);
2218                }
2219            }
2220        }
2221    }
2222
2223    /// Synchronously sets a message callback for the WebSocket stream on the current thread.
2224    ///
2225    /// # Arguments
2226    ///
2227    /// * `callback_fn` - A function that will be called with the deserialized message payload
2228    ///
2229    /// # Panics
2230    ///
2231    /// Panics if the thread runtime fails to be created or if the thread join fails
2232    ///
2233    /// # Examples
2234    ///
2235    ///
2236    /// let stream = `Arc::new(WebsocketStream::new())`;
2237    /// `stream.on_message(|data`: `MyType`| {
2238    ///     // Handle the deserialized message
2239    /// });
2240    ///
2241    pub fn on_message<F>(self: &Arc<Self>, callback_fn: F)
2242    where
2243        T: Send + Sync,
2244        F: Fn(T) + Send + Sync + 'static,
2245    {
2246        let handler: Arc<Self> = Arc::clone(self);
2247
2248        std::thread::spawn(move || {
2249            let rt = tokio::runtime::Builder::new_current_thread()
2250                .enable_all()
2251                .build()
2252                .expect("failed to build Tokio runtime");
2253
2254            rt.block_on(handler.on("message", callback_fn));
2255        })
2256        .join()
2257        .expect("on_message thread panicked");
2258    }
2259
2260    /// Unsubscribes from the current WebSocket stream and removes the associated callback.
2261    ///
2262    /// This method performs the following actions:
2263    /// - Removes the current callback associated with the stream
2264    /// - Removes the callback from the connection's stream callbacks
2265    /// - Asynchronously unsubscribes from the stream using the WebSocket streams base
2266    ///
2267    /// # Panics
2268    ///
2269    /// Panics if the stream is not subscribed to
2270    ///
2271    /// # Notes
2272    /// - If no callback is present, no action is taken
2273    /// - Spawns an asynchronous task to handle the unsubscription process
2274    pub async fn unsubscribe(&self) {
2275        let maybe_cb = {
2276            let mut guard = self.callback.lock().await;
2277            guard.take()
2278        };
2279
2280        if let Some(cb) = maybe_cb {
2281            match &self.websocket_base {
2282                WebsocketBase::WebsocketStreams(ws_streams) => {
2283                    let conn = {
2284                        let map = ws_streams.connection_streams.lock().await;
2285                        map.get(&self.stream_or_id)
2286                            .cloned()
2287                            .expect("stream must have been subscribed")
2288                    };
2289
2290                    {
2291                        let mut conn_state = conn.state.lock().await;
2292                        if let Some(list) = conn_state.stream_callbacks.get_mut(&self.stream_or_id)
2293                        {
2294                            list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2295                        }
2296                    }
2297
2298                    let stream = self.stream_or_id.clone();
2299                    let id = self.id.clone();
2300                    let websocket_streams_base = Arc::clone(ws_streams);
2301                    spawn(async move {
2302                        websocket_streams_base.unsubscribe(vec![stream], id).await;
2303                    });
2304                }
2305                WebsocketBase::WebsocketApi(ws_api) => {
2306                    let mut stream_callbacks = ws_api.stream_callbacks.lock().await;
2307                    if let Some(list) = stream_callbacks.get_mut(&self.stream_or_id) {
2308                        list.retain(|existing| !Arc::ptr_eq(existing, &cb));
2309                    }
2310                }
2311            }
2312        }
2313    }
2314}
2315
2316pub async fn create_stream_handler<T>(
2317    websocket_base: WebsocketBase,
2318    stream_or_id: String,
2319    id: Option<String>,
2320) -> Arc<WebsocketStream<T>>
2321where
2322    T: DeserializeOwned + Send + 'static,
2323{
2324    match &websocket_base {
2325        WebsocketBase::WebsocketStreams(ws_streams) => {
2326            ws_streams
2327                .clone()
2328                .subscribe(vec![stream_or_id.clone()], id.clone())
2329                .await;
2330        }
2331        WebsocketBase::WebsocketApi(_) => {}
2332    }
2333
2334    Arc::new(WebsocketStream {
2335        websocket_base,
2336        stream_or_id,
2337        id,
2338        callback: Mutex::new(None),
2339        _phantom: PhantomData,
2340    })
2341}
2342
2343#[cfg(test)]
2344mod tests {
2345    use crate::TOKIO_SHARED_RT;
2346    use crate::common::utils::{SignatureGenerator, build_user_agent};
2347    use crate::common::websocket::{
2348        PendingRequest, ReconnectEntry, WebsocketApi, WebsocketBase, WebsocketCommon,
2349        WebsocketConnection, WebsocketEvent, WebsocketEventEmitter, WebsocketHandler,
2350        WebsocketMessageSendOptions, WebsocketMode, WebsocketStream, WebsocketStreams,
2351        create_stream_handler,
2352    };
2353    use crate::config::{ConfigurationWebsocketApi, ConfigurationWebsocketStreams, PrivateKey};
2354    use crate::errors::WebsocketError;
2355    use crate::models::TimeUnit;
2356    use async_trait::async_trait;
2357    use futures::{SinkExt, StreamExt};
2358    use http::header::USER_AGENT;
2359    use regex::Regex;
2360    use serde_json::{Value, json};
2361    use std::collections::{BTreeMap, HashSet};
2362    use std::marker::PhantomData;
2363    use std::net::SocketAddr;
2364    use std::sync::{
2365        Arc,
2366        atomic::{AtomicBool, AtomicUsize, Ordering},
2367    };
2368    use tokio::net::TcpListener;
2369    use tokio::sync::{Mutex, mpsc::unbounded_channel, oneshot};
2370    use tokio::time::{Duration, advance, pause, resume, sleep, timeout};
2371    use tokio_tungstenite::{accept_async, accept_hdr_async, tungstenite, tungstenite::Message};
2372    use tungstenite::handshake::server::Request;
2373
2374    fn subscribe_events(common: &WebsocketCommon) -> Arc<Mutex<Vec<WebsocketEvent>>> {
2375        let events = Arc::new(Mutex::new(Vec::new()));
2376        let events_clone = events.clone();
2377        common.events.subscribe(move |event| {
2378            let events_clone = events_clone.clone();
2379            tokio::spawn(async move {
2380                events_clone.lock().await.push(event);
2381            });
2382        });
2383        events
2384    }
2385
2386    async fn create_connection(
2387        id: &str,
2388        has_writer: bool,
2389        reconnection_pending: bool,
2390        renewal_pending: bool,
2391        close_initiated: bool,
2392    ) -> Arc<WebsocketConnection> {
2393        let conn = WebsocketConnection::new(id);
2394        let mut st = conn.state.lock().await;
2395        st.reconnection_pending = reconnection_pending;
2396        st.renewal_pending = renewal_pending;
2397        st.close_initiated = close_initiated;
2398        if has_writer {
2399            let (tx, _) = unbounded_channel::<Message>();
2400            st.ws_write_tx = Some(tx);
2401        } else {
2402            st.ws_write_tx = None;
2403        }
2404        drop(st);
2405        conn
2406    }
2407
2408    fn create_websocket_api(time_unit: Option<TimeUnit>) -> Arc<WebsocketApi> {
2409        let sig_gen = SignatureGenerator::new(
2410            Some("api_secret".into()),
2411            None::<PrivateKey>,
2412            None::<String>,
2413        );
2414        let config = ConfigurationWebsocketApi {
2415            api_key: Some("api_key".into()),
2416            api_secret: Some("api_secret".into()),
2417            private_key: None,
2418            private_key_passphrase: None,
2419            ws_url: Some("wss://example.com".into()),
2420            mode: WebsocketMode::Single,
2421            reconnect_delay: 1000,
2422            signature_gen: sig_gen,
2423            timeout: 500,
2424            time_unit,
2425            agent: None,
2426            user_agent: build_user_agent("product"),
2427        };
2428        let conn = WebsocketConnection::new("c1");
2429        WebsocketApi::new(config, vec![conn])
2430    }
2431
2432    fn create_websocket_streams(
2433        ws_url: Option<&str>,
2434        conns: Option<Vec<Arc<WebsocketConnection>>>,
2435    ) -> Arc<WebsocketStreams> {
2436        let mut connections: Vec<Arc<WebsocketConnection>> = vec![];
2437        if conns.is_none() {
2438            connections.push(WebsocketConnection::new("c1"));
2439            connections.push(WebsocketConnection::new("c2"));
2440        } else {
2441            connections = conns.expect("Expected connections to be set");
2442        }
2443        let config = ConfigurationWebsocketStreams {
2444            ws_url: Some(ws_url.unwrap_or("example.com").to_string()),
2445            mode: WebsocketMode::Single,
2446            reconnect_delay: 500,
2447            time_unit: None,
2448            agent: None,
2449            user_agent: build_user_agent("product"),
2450        };
2451        WebsocketStreams::new(config, connections)
2452    }
2453
2454    mod event_emitter {
2455        use super::*;
2456
2457        #[test]
2458        fn event_emitter_subscribe_and_emit() {
2459            TOKIO_SHARED_RT.block_on(async {
2460                let emitter = WebsocketEventEmitter::new();
2461                let (tx, rx) = oneshot::channel();
2462                let tx = Arc::new(std::sync::Mutex::new(Some(tx)));
2463                let tx_clone = tx.clone();
2464                let _sub = emitter.subscribe(move |event| {
2465                    if let Some(sender) = tx_clone.lock().unwrap().take() {
2466                        let _ = sender.send(event);
2467                    }
2468                });
2469                emitter.emit(WebsocketEvent::Open);
2470                let received = timeout(Duration::from_millis(100), rx)
2471                    .await
2472                    .expect("timed out");
2473                assert_eq!(received, Ok(WebsocketEvent::Open));
2474            });
2475        }
2476    }
2477
2478    mod websocket_common {
2479        use super::*;
2480
2481        mod initialisation {
2482            use super::*;
2483
2484            #[test]
2485            fn single_mode() {
2486                TOKIO_SHARED_RT.block_on(async {
2487                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
2488                    assert_eq!(common.connection_pool.len(), 1);
2489                });
2490            }
2491
2492            #[test]
2493            fn pool_mode() {
2494                TOKIO_SHARED_RT.block_on(async {
2495                    let common =
2496                        WebsocketCommon::new(vec![], WebsocketMode::Pool(3), 0, None, None);
2497                    assert_eq!(common.connection_pool.len(), 3);
2498                });
2499            }
2500        }
2501
2502        mod spawn_reconnect_loop {
2503            use super::*;
2504
2505            #[test]
2506            fn successful_reconnect_entry_triggers_init_connect() {
2507                TOKIO_SHARED_RT.block_on(async {
2508                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2509                    let addr = listener.local_addr().unwrap();
2510                    tokio::spawn(async move {
2511                        if let Ok((stream, _)) = listener.accept().await {
2512                            let mut ws = accept_async(stream).await.unwrap();
2513                            let _ = ws.close(None).await;
2514                        }
2515                    });
2516
2517                    let conn = WebsocketConnection::new("c1");
2518                    let common = WebsocketCommon::new(
2519                        vec![conn.clone()],
2520                        WebsocketMode::Single,
2521                        10,
2522                        None,
2523                        None,
2524                    );
2525                    let url = format!("ws://{addr}");
2526                    common
2527                        .reconnect_tx
2528                        .send(ReconnectEntry {
2529                            connection_id: "c1".into(),
2530                            url: url.clone(),
2531                            is_renewal: false,
2532                        })
2533                        .await
2534                        .unwrap();
2535
2536                    sleep(Duration::from_secs(2)).await;
2537
2538                    let st = conn.state.lock().await;
2539                    assert!(st.ws_write_tx.is_some());
2540                });
2541            }
2542
2543            #[test]
2544            fn reconnect_entry_with_unknown_id_is_ignored() {
2545                TOKIO_SHARED_RT.block_on(async {
2546                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2547                    let addr = listener.local_addr().unwrap();
2548                    tokio::spawn(async move {
2549                        if let Ok((stream, _)) = listener.accept().await {
2550                            let mut ws = accept_async(stream).await.unwrap();
2551                            let _ = ws.close(None).await;
2552                        }
2553                    });
2554
2555                    let conn = WebsocketConnection::new("c1");
2556                    let common = WebsocketCommon::new(
2557                        vec![conn.clone()],
2558                        WebsocketMode::Single,
2559                        5,
2560                        None,
2561                        None,
2562                    );
2563                    let url = format!("ws://{addr}");
2564                    common
2565                        .reconnect_tx
2566                        .send(ReconnectEntry {
2567                            connection_id: "other".into(),
2568                            url,
2569                            is_renewal: false,
2570                        })
2571                        .await
2572                        .unwrap();
2573
2574                    sleep(Duration::from_secs(1)).await;
2575
2576                    let st = conn.state.lock().await;
2577                    assert!(st.ws_write_tx.is_none());
2578                });
2579            }
2580
2581            #[test]
2582            fn renewal_entries_bypass_initial_delay() {
2583                TOKIO_SHARED_RT.block_on(async {
2584                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2585                    let addr = listener.local_addr().unwrap();
2586                    tokio::spawn(async move {
2587                        if let Ok((stream, _)) = listener.accept().await {
2588                            let mut ws = accept_async(stream).await.unwrap();
2589                            let _ = ws.close(None).await;
2590                        }
2591                    });
2592
2593                    let conn = WebsocketConnection::new("renew");
2594                    let common = WebsocketCommon::new(
2595                        vec![conn.clone()],
2596                        WebsocketMode::Single,
2597                        200,
2598                        None,
2599                        None,
2600                    );
2601                    let url = format!("ws://{addr}");
2602                    common
2603                        .reconnect_tx
2604                        .send(ReconnectEntry {
2605                            connection_id: "renew".into(),
2606                            url: url.clone(),
2607                            is_renewal: true,
2608                        })
2609                        .await
2610                        .unwrap();
2611
2612                    sleep(Duration::from_secs(2)).await;
2613
2614                    let st = conn.state.lock().await;
2615
2616                    assert!(st.ws_write_tx.is_some());
2617                });
2618            }
2619
2620            #[test]
2621            fn non_renewal_entries_respect_initial_delay() {
2622                TOKIO_SHARED_RT.block_on(async {
2623                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2624                    let addr = listener.local_addr().unwrap();
2625                    tokio::spawn(async move {
2626                        if let Ok((stream, _)) = listener.accept().await {
2627                            let mut ws = accept_async(stream).await.unwrap();
2628                            let _ = ws.close(None).await;
2629                        }
2630                    });
2631
2632                    let conn = WebsocketConnection::new("nonrenew");
2633                    let common = WebsocketCommon::new(
2634                        vec![conn.clone()],
2635                        WebsocketMode::Single,
2636                        200,
2637                        None,
2638                        None,
2639                    );
2640                    let url = format!("ws://{addr}");
2641                    common
2642                        .reconnect_tx
2643                        .send(ReconnectEntry {
2644                            connection_id: "nonrenew".into(),
2645                            url: url.clone(),
2646                            is_renewal: false,
2647                        })
2648                        .await
2649                        .unwrap();
2650
2651                    sleep(Duration::from_millis(100)).await;
2652                    assert!(conn.state.lock().await.ws_write_tx.is_none());
2653
2654                    sleep(Duration::from_secs(2)).await;
2655
2656                    assert!(conn.state.lock().await.ws_write_tx.is_some());
2657                });
2658            }
2659        }
2660
2661        mod spawn_renewal_loop {
2662            use super::*;
2663
2664            #[tokio::test]
2665            async fn scheduling_renewal_does_not_panic_for_known_connection() {
2666                pause();
2667
2668                let conn = WebsocketConnection::new("known");
2669                let common =
2670                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
2671                let url = "wss://example".to_string();
2672                common
2673                    .renewal_tx
2674                    .send((conn.id.clone(), url))
2675                    .await
2676                    .unwrap();
2677                advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
2678
2679                resume();
2680            }
2681
2682            #[tokio::test]
2683            async fn scheduling_renewal_ignored_for_unknown_connection() {
2684                pause();
2685
2686                let conn = WebsocketConnection::new("c1");
2687                let common =
2688                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
2689                common
2690                    .renewal_tx
2691                    .send(("other".into(), "u".into()))
2692                    .await
2693                    .unwrap();
2694                advance(Duration::from_secs(23 * 60 * 60 + 1)).await;
2695
2696                resume();
2697            }
2698        }
2699
2700        mod is_connection_ready {
2701            use super::*;
2702
2703            #[test]
2704            fn is_connection_ready() {
2705                TOKIO_SHARED_RT.block_on(async {
2706                    let conn = WebsocketConnection::new("c1");
2707                    let common = WebsocketCommon::new(
2708                        vec![conn.clone()],
2709                        WebsocketMode::Single,
2710                        0,
2711                        None,
2712                        None,
2713                    );
2714                    assert!(!common.is_connection_ready(&conn, false).await);
2715                    assert!(common.is_connection_ready(&conn, true).await);
2716                });
2717            }
2718
2719            #[test]
2720            fn connection_ready_basic() {
2721                TOKIO_SHARED_RT.block_on(async {
2722                    let conn = create_connection("c1", true, false, false, false).await;
2723                    let common = WebsocketCommon::new(
2724                        vec![conn.clone()],
2725                        WebsocketMode::Single,
2726                        0,
2727                        None,
2728                        None,
2729                    );
2730                    assert!(common.is_connection_ready(&conn, false).await);
2731                });
2732            }
2733
2734            #[test]
2735            fn connection_not_ready_without_writer() {
2736                TOKIO_SHARED_RT.block_on(async {
2737                    let conn = create_connection("c1", false, false, false, false).await;
2738                    let common = WebsocketCommon::new(
2739                        vec![conn.clone()],
2740                        WebsocketMode::Single,
2741                        0,
2742                        None,
2743                        None,
2744                    );
2745                    assert!(!common.is_connection_ready(&conn, false).await);
2746                    assert!(common.is_connection_ready(&conn, true).await);
2747                });
2748            }
2749
2750            #[test]
2751            fn connection_not_ready_when_flagged() {
2752                TOKIO_SHARED_RT.block_on(async {
2753                    let conn1 = create_connection("c1", true, true, false, false).await;
2754                    let conn2 = create_connection("c2", true, false, true, false).await;
2755                    let conn3 = create_connection("c3", true, false, false, true).await;
2756
2757                    let common = WebsocketCommon::new(
2758                        vec![conn1.clone(), conn2.clone(), conn3.clone()],
2759                        WebsocketMode::Pool(3),
2760                        0,
2761                        None,
2762                        None,
2763                    );
2764
2765                    assert!(!common.is_connection_ready(&conn1, false).await);
2766                    assert!(!common.is_connection_ready(&conn2, false).await);
2767                    assert!(!common.is_connection_ready(&conn3, false).await);
2768                });
2769            }
2770        }
2771
2772        mod is_connected {
2773            use super::*;
2774
2775            #[test]
2776            fn with_pool_various_connections() {
2777                TOKIO_SHARED_RT.block_on(async {
2778                    let conn_a = create_connection("a", true, false, false, false).await;
2779                    let conn_b = create_connection("b", false, false, false, false).await;
2780                    let conn_c = create_connection("c", true, true, false, false).await;
2781                    let pool = vec![conn_a.clone(), conn_b.clone(), conn_c.clone()];
2782                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(3), 0, None, None);
2783
2784                    assert!(common.is_connected(None).await);
2785                    assert!(common.is_connected(Some(&conn_a)).await);
2786                    assert!(!common.is_connected(Some(&conn_b)).await);
2787                    assert!(!common.is_connected(Some(&conn_c)).await);
2788                });
2789            }
2790
2791            #[test]
2792            fn with_pool_all_bad_connections() {
2793                TOKIO_SHARED_RT.block_on(async {
2794                    let bad1 = create_connection("c1", false, false, false, false).await;
2795                    let bad2 = create_connection("c2", true, true, false, false).await;
2796                    let bad3 = create_connection("c3", true, false, false, true).await;
2797                    let common = WebsocketCommon::new(
2798                        vec![bad1, bad2, bad3],
2799                        WebsocketMode::Pool(3),
2800                        0,
2801                        None,
2802                        None,
2803                    );
2804
2805                    assert!(!common.is_connected(None).await);
2806                });
2807            }
2808
2809            #[test]
2810            fn with_pool_ignore_close_initiated() {
2811                TOKIO_SHARED_RT.block_on(async {
2812                    let good = create_connection("c1", true, false, false, false).await;
2813                    let closed = create_connection("c2", true, false, false, true).await;
2814                    let bad = create_connection("c3", false, false, false, false).await;
2815                    let common = WebsocketCommon::new(
2816                        vec![closed.clone(), good.clone(), bad.clone()],
2817                        WebsocketMode::Pool(3),
2818                        0,
2819                        None,
2820                        None,
2821                    );
2822
2823                    assert!(common.is_connected(None).await);
2824                    assert!(!common.is_connected(Some(&closed)).await);
2825                });
2826            }
2827        }
2828
2829        mod get_connection {
2830            use super::*;
2831
2832            #[test]
2833            fn single_mode() {
2834                TOKIO_SHARED_RT.block_on(async {
2835                    let common = WebsocketCommon::new(vec![], WebsocketMode::Single, 0, None, None);
2836                    let conn = common
2837                        .get_connection(false)
2838                        .await
2839                        .expect("should get connection");
2840                    assert_eq!(conn.id, common.connection_pool[0].id);
2841                });
2842            }
2843
2844            #[test]
2845            fn pool_mode_not_ready() {
2846                TOKIO_SHARED_RT.block_on(async {
2847                    let common =
2848                        WebsocketCommon::new(vec![], WebsocketMode::Pool(2), 0, None, None);
2849                    let result = common.get_connection(false).await;
2850                    assert!(matches!(
2851                        result,
2852                        Err(crate::errors::WebsocketError::NotConnected)
2853                    ));
2854                });
2855            }
2856
2857            #[test]
2858            fn pool_mode_with_ready() {
2859                TOKIO_SHARED_RT.block_on(async {
2860                    let conn1 = WebsocketConnection::new("c1");
2861                    let conn2 = WebsocketConnection::new("c2");
2862                    let (tx1, _rx1) = unbounded_channel();
2863                    {
2864                        let mut s1 = conn1.state.lock().await;
2865                        s1.ws_write_tx = Some(tx1);
2866                    }
2867                    let pool = vec![conn1.clone(), conn2.clone()];
2868                    let common = WebsocketCommon::new(pool, WebsocketMode::Pool(2), 0, None, None);
2869                    let result = common.get_connection(false).await;
2870                    assert!(result.is_ok());
2871                    let chosen = result.unwrap();
2872                    assert_eq!(chosen.id, conn1.id);
2873                });
2874            }
2875        }
2876
2877        mod close_connection_gracefully {
2878            use super::*;
2879
2880            #[tokio::test]
2881            async fn waits_for_pending_requests_then_closes() {
2882                pause();
2883
2884                let conn = WebsocketConnection::new("c1");
2885                let (tx, mut rx) = unbounded_channel::<Message>();
2886                let (req_tx, _req_rx) = oneshot::channel();
2887                {
2888                    let mut st = conn.state.lock().await;
2889                    st.pending_requests
2890                        .insert("r".to_string(), PendingRequest { completion: req_tx });
2891                }
2892                let common =
2893                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
2894                let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
2895                advance(Duration::from_secs(1)).await;
2896                {
2897                    let mut st = conn.state.lock().await;
2898                    st.pending_requests.clear();
2899                }
2900                conn.drain_notify.notify_waiters();
2901                advance(Duration::from_secs(1)).await;
2902                close_fut.await.unwrap();
2903                match rx.try_recv() {
2904                    Ok(Message::Close(_)) => {}
2905                    other => panic!("expected Close, got {other:?}"),
2906                }
2907
2908                resume();
2909            }
2910
2911            #[tokio::test]
2912            async fn force_closes_after_timeout() {
2913                pause();
2914
2915                let conn = WebsocketConnection::new("c2");
2916                let (tx, mut rx) = unbounded_channel::<Message>();
2917                let (req_tx, _req_rx) = oneshot::channel();
2918                {
2919                    let mut st = conn.state.lock().await;
2920                    st.pending_requests.insert(
2921                        "request_id".to_string(),
2922                        PendingRequest { completion: req_tx },
2923                    );
2924                }
2925                let common =
2926                    WebsocketCommon::new(vec![conn.clone()], WebsocketMode::Single, 0, None, None);
2927                let close_fut = common.close_connection_gracefully(tx.clone(), conn.clone());
2928                advance(Duration::from_secs(30)).await;
2929                close_fut.await.unwrap();
2930                match rx.try_recv() {
2931                    Ok(Message::Close(_)) => {}
2932                    other => panic!("expected Close on timeout, got {other:?}"),
2933                }
2934
2935                resume();
2936            }
2937        }
2938
2939        mod get_reconnect_url {
2940            use super::*;
2941
2942            struct DummyHandler {
2943                url: String,
2944            }
2945
2946            #[async_trait::async_trait]
2947            impl WebsocketHandler for DummyHandler {
2948                async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
2949                async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
2950                async fn get_reconnect_url(
2951                    &self,
2952                    _default_url: String,
2953                    _connection: Arc<WebsocketConnection>,
2954                ) -> String {
2955                    self.url.clone()
2956                }
2957            }
2958
2959            #[test]
2960            fn returns_default_when_no_handler() {
2961                TOKIO_SHARED_RT.block_on(async {
2962                    let conn = WebsocketConnection::new("c1");
2963                    let common = WebsocketCommon::new(
2964                        vec![conn.clone()],
2965                        WebsocketMode::Single,
2966                        0,
2967                        None,
2968                        None,
2969                    );
2970                    let default = "wss://default".to_string();
2971                    let result = common.get_reconnect_url(&default, conn.clone()).await;
2972                    assert_eq!(result, default);
2973                });
2974            }
2975
2976            #[test]
2977            fn returns_handler_url_when_set() {
2978                TOKIO_SHARED_RT.block_on(async {
2979                    let conn = WebsocketConnection::new("c2");
2980                    let handler = Arc::new(DummyHandler {
2981                        url: "wss://custom".into(),
2982                    });
2983                    conn.set_handler(handler).await;
2984                    let common = WebsocketCommon::new(
2985                        vec![conn.clone()],
2986                        WebsocketMode::Single,
2987                        0,
2988                        None,
2989                        None,
2990                    );
2991                    let default = "wss://default".to_string();
2992                    let result = common.get_reconnect_url(&default, conn.clone()).await;
2993                    assert_eq!(result, "wss://custom");
2994                });
2995            }
2996        }
2997
2998        mod on_open {
2999            use super::*;
3000
3001            struct DummyHandler {
3002                called: Arc<Mutex<bool>>,
3003                opened_url: Arc<Mutex<Option<String>>>,
3004            }
3005
3006            #[async_trait]
3007            impl WebsocketHandler for DummyHandler {
3008                async fn on_open(&self, url: String, _connection: Arc<WebsocketConnection>) {
3009                    let mut flag = self.called.lock().await;
3010                    *flag = true;
3011                    let mut store = self.opened_url.lock().await;
3012                    *store = Some(url);
3013                }
3014                async fn on_message(&self, _data: String, _connection: Arc<WebsocketConnection>) {}
3015                async fn get_reconnect_url(
3016                    &self,
3017                    default_url: String,
3018                    _connection: Arc<WebsocketConnection>,
3019                ) -> String {
3020                    default_url
3021                }
3022            }
3023
3024            #[test]
3025            fn emits_open_and_calls_handler() {
3026                TOKIO_SHARED_RT.block_on(async {
3027                    let conn = WebsocketConnection::new("c1");
3028                    let called = Arc::new(Mutex::new(false));
3029                    let opened_url = Arc::new(Mutex::new(None));
3030                    let handler = Arc::new(DummyHandler {
3031                        called: called.clone(),
3032                        opened_url: opened_url.clone(),
3033                    });
3034
3035                    conn.set_handler(handler.clone()).await;
3036                    let common = WebsocketCommon::new(
3037                        vec![conn.clone()],
3038                        WebsocketMode::Single,
3039                        0,
3040                        None,
3041                        None,
3042                    );
3043                    let events = subscribe_events(&common);
3044                    common
3045                        .on_open("wss://example.com".into(), conn.clone(), None)
3046                        .await;
3047
3048                    sleep(std::time::Duration::from_millis(10)).await;
3049
3050                    let evs = events.lock().await;
3051                    assert!(evs.iter().any(|e| matches!(e, WebsocketEvent::Open)));
3052                    assert!(*called.lock().await);
3053                    assert_eq!(
3054                        opened_url.lock().await.as_deref(),
3055                        Some("wss://example.com")
3056                    );
3057                });
3058            }
3059
3060            #[test]
3061            fn handles_renewal_pending_and_closes_old_writer() {
3062                TOKIO_SHARED_RT.block_on(async {
3063                    let conn = WebsocketConnection::new("c2");
3064                    let (old_tx, mut old_rx) = unbounded_channel::<Message>();
3065                    {
3066                        let mut st = conn.state.lock().await;
3067                        st.renewal_pending = true;
3068                    }
3069                    let common = WebsocketCommon::new(
3070                        vec![conn.clone()],
3071                        WebsocketMode::Single,
3072                        0,
3073                        None,
3074                        None,
3075                    );
3076                    common
3077                        .on_open("url".into(), conn.clone(), Some(old_tx.clone()))
3078                        .await;
3079                    assert!(!conn.state.lock().await.renewal_pending);
3080                    match old_rx.try_recv() {
3081                        Ok(Message::Close(_)) => {}
3082                        other => panic!("expected Close, got {other:?}"),
3083                    }
3084                });
3085            }
3086        }
3087
3088        mod on_message {
3089            use super::*;
3090
3091            struct DummyHandler {
3092                called_with: Arc<Mutex<Vec<String>>>,
3093            }
3094
3095            #[async_trait]
3096            impl WebsocketHandler for DummyHandler {
3097                async fn on_open(&self, _url: String, _connection: Arc<WebsocketConnection>) {}
3098                async fn on_message(&self, data: String, _connection: Arc<WebsocketConnection>) {
3099                    self.called_with.lock().await.push(data);
3100                }
3101                async fn get_reconnect_url(
3102                    &self,
3103                    default_url: String,
3104                    _connection: Arc<WebsocketConnection>,
3105                ) -> String {
3106                    default_url
3107                }
3108            }
3109
3110            #[test]
3111            fn emits_message_event_without_handler() {
3112                TOKIO_SHARED_RT.block_on(async {
3113                    let conn = WebsocketConnection::new("c1");
3114                    let common = WebsocketCommon::new(
3115                        vec![conn.clone()],
3116                        WebsocketMode::Single,
3117                        0,
3118                        None,
3119                        None,
3120                    );
3121                    let events = subscribe_events(&common);
3122                    common.on_message("msg".into(), conn.clone()).await;
3123
3124                    sleep(Duration::from_millis(10)).await;
3125
3126                    let locked = events.lock().await;
3127                    assert!(
3128                        locked
3129                            .iter()
3130                            .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
3131                    );
3132                });
3133            }
3134
3135            #[test]
3136            fn calls_handler_and_emits_message() {
3137                TOKIO_SHARED_RT.block_on(async {
3138                    let conn = WebsocketConnection::new("c2");
3139                    let called = Arc::new(Mutex::new(Vec::new()));
3140                    let handler = Arc::new(DummyHandler {
3141                        called_with: called.clone(),
3142                    });
3143                    conn.set_handler(handler.clone()).await;
3144
3145                    let common = WebsocketCommon::new(
3146                        vec![conn.clone()],
3147                        WebsocketMode::Single,
3148                        0,
3149                        None,
3150                        None,
3151                    );
3152                    let events = subscribe_events(&common);
3153                    common.on_message("msg".into(), conn.clone()).await;
3154
3155                    sleep(Duration::from_millis(10)).await;
3156
3157                    let evs = events.lock().await;
3158                    assert!(
3159                        evs.iter()
3160                            .any(|e| matches!(e, WebsocketEvent::Message(m) if m == "msg"))
3161                    );
3162                    let msgs = called.lock().await;
3163                    assert_eq!(msgs.as_slice(), &["msg".to_string()]);
3164                });
3165            }
3166        }
3167
3168        mod create_websocket {
3169            use super::*;
3170
3171            #[test]
3172            fn successful_connection() {
3173                TOKIO_SHARED_RT.block_on(async {
3174                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3175                    let addr: SocketAddr = listener.local_addr().unwrap();
3176
3177                    let expected_ua = build_user_agent("product");
3178                    let expected_ua_clone = expected_ua.clone();
3179
3180                    tokio::spawn(async move {
3181                        if let Ok((stream, _)) = listener.accept().await {
3182                            let callback = |req: &Request, resp| {
3183                                let got = req
3184                                    .headers()
3185                                    .get(USER_AGENT)
3186                                    .expect("no USER_AGENT header in WS handshake")
3187                                    .to_str()
3188                                    .expect("invalid USER_AGENT header");
3189                                assert_eq!(got, expected_ua_clone, "User-Agent mismatch");
3190                                Ok(resp)
3191                            };
3192                            let _ = accept_hdr_async(stream, callback).await.unwrap();
3193                        }
3194                    });
3195
3196                    let url = format!("ws://{addr}");
3197                    let res =
3198                        WebsocketCommon::create_websocket(&url, None, Some(expected_ua)).await;
3199                    assert!(res.is_ok(), "handshake failed: {res:?}");
3200                });
3201            }
3202
3203            #[test]
3204            fn invalid_url_returns_handshake_error() {
3205                TOKIO_SHARED_RT.block_on(async {
3206                    let res =
3207                        WebsocketCommon::create_websocket("not-a-valid-url", None, None).await;
3208                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3209                });
3210            }
3211
3212            #[test]
3213            fn unreachable_host_returns_handshake_error() {
3214                TOKIO_SHARED_RT.block_on(async {
3215                    let res =
3216                        WebsocketCommon::create_websocket("ws://127.0.0.1:1", None, None).await;
3217                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3218                });
3219            }
3220        }
3221
3222        mod connect_pool {
3223            use super::*;
3224
3225            #[test]
3226            fn connects_all_in_pool() {
3227                TOKIO_SHARED_RT.block_on(async {
3228                    let pool_size = 3;
3229                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3230                    let addr = listener.local_addr().unwrap();
3231                    tokio::spawn(async move {
3232                        for _ in 0..pool_size {
3233                            if let Ok((stream, _)) = listener.accept().await {
3234                                let mut ws = accept_async(stream).await.unwrap();
3235                                let _ = ws.close(None).await;
3236                            }
3237                        }
3238                    });
3239                    let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
3240                        .map(|i| WebsocketConnection::new(format!("c{i}")))
3241                        .collect();
3242                    let common = WebsocketCommon::new(
3243                        conns.clone(),
3244                        WebsocketMode::Pool(pool_size),
3245                        0,
3246                        None,
3247                        None,
3248                    );
3249                    let url = format!("ws://{addr}");
3250                    common.clone().connect_pool(&url).await.unwrap();
3251                    for conn in conns {
3252                        let st = conn.state.lock().await;
3253                        assert!(st.ws_write_tx.is_some());
3254                    }
3255                });
3256            }
3257
3258            #[test]
3259            fn fails_if_any_refused() {
3260                TOKIO_SHARED_RT.block_on(async {
3261                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3262                    let addr = listener.local_addr().unwrap();
3263                    let pool_size = 3;
3264                    tokio::spawn(async move {
3265                        for _ in 0..2 {
3266                            if let Ok((stream, _)) = listener.accept().await {
3267                                let mut ws = accept_async(stream).await.unwrap();
3268                                let _ = ws.close(None).await;
3269                            }
3270                        }
3271                    });
3272                    let mut conns = Vec::new();
3273                    let valid_url = format!("ws://{addr}");
3274                    for i in 0..2 {
3275                        conns.push(WebsocketConnection::new(format!("c{i}")));
3276                    }
3277                    conns.push(WebsocketConnection::new("bad"));
3278                    let common = WebsocketCommon::new(
3279                        conns.clone(),
3280                        WebsocketMode::Pool(pool_size),
3281                        0,
3282                        None,
3283                        None,
3284                    );
3285                    let res = common.clone().connect_pool(&valid_url).await;
3286                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3287                });
3288            }
3289
3290            #[test]
3291            fn fails_on_invalid_url() {
3292                TOKIO_SHARED_RT.block_on(async {
3293                    let conns = vec![WebsocketConnection::new("c1")];
3294                    let common = WebsocketCommon::new(conns, WebsocketMode::Pool(1), 0, None, None);
3295                    let res = common.connect_pool("not-a-url").await;
3296                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3297                });
3298            }
3299
3300            #[test]
3301            fn fails_if_mixed_success_and_invalid_url() {
3302                TOKIO_SHARED_RT.block_on(async {
3303                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3304                    let addr = listener.local_addr().unwrap();
3305                    tokio::spawn(async move {
3306                        if let Ok((stream, _)) = listener.accept().await {
3307                            let mut ws = accept_async(stream).await.unwrap();
3308                            let _ = ws.close(None).await;
3309                        }
3310                    });
3311                    let good = WebsocketConnection::new("good");
3312                    let bad = WebsocketConnection::new("bad");
3313                    let common = WebsocketCommon::new(
3314                        vec![good, bad],
3315                        WebsocketMode::Pool(2),
3316                        0,
3317                        None,
3318                        None,
3319                    );
3320                    let url = format!("ws://{addr}");
3321                    let res = common.connect_pool(&url).await;
3322                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3323                });
3324            }
3325
3326            #[test]
3327            fn init_connect_invoked_for_each() {
3328                TOKIO_SHARED_RT.block_on(async {
3329                    let pool_size = 2;
3330                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3331                    let addr = listener.local_addr().unwrap();
3332                    tokio::spawn(async move {
3333                        for _ in 0..pool_size {
3334                            if let Ok((stream, _)) = listener.accept().await {
3335                                let mut ws = accept_async(stream).await.unwrap();
3336                                let _ = ws.close(None).await;
3337                            }
3338                        }
3339                    });
3340                    let conns: Vec<Arc<WebsocketConnection>> = (0..pool_size)
3341                        .map(|i| WebsocketConnection::new(format!("c{i}")))
3342                        .collect();
3343                    let common = WebsocketCommon::new(
3344                        conns.clone(),
3345                        WebsocketMode::Pool(pool_size),
3346                        0,
3347                        None,
3348                        None,
3349                    );
3350                    let url = format!("ws://{addr}");
3351                    common.clone().connect_pool(&url).await.unwrap();
3352                    for conn in conns {
3353                        let st = conn.state.lock().await;
3354                        assert!(st.ws_write_tx.is_some());
3355                    }
3356                });
3357            }
3358
3359            #[test]
3360            fn single_mode_uses_first_connection() {
3361                TOKIO_SHARED_RT.block_on(async {
3362                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3363                    let addr = listener.local_addr().unwrap();
3364                    tokio::spawn(async move {
3365                        if let Ok((stream, _)) = listener.accept().await {
3366                            let mut ws = accept_async(stream).await.unwrap();
3367                            let _ = ws.close(None).await;
3368                        }
3369                    });
3370                    let conn = WebsocketConnection::new("c1");
3371                    let common = WebsocketCommon::new(
3372                        vec![conn.clone()],
3373                        WebsocketMode::Single,
3374                        0,
3375                        None,
3376                        None,
3377                    );
3378                    let url = format!("ws://{addr}");
3379                    common.connect_pool(&url).await.unwrap();
3380                    let st = conn.state.lock().await;
3381                    assert!(st.ws_write_tx.is_some());
3382                });
3383            }
3384        }
3385
3386        mod init_connect {
3387            use super::*;
3388
3389            #[test]
3390            fn pool_mode_none_connection_uses_first() {
3391                TOKIO_SHARED_RT.block_on(async {
3392                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3393                    let addr = listener.local_addr().unwrap();
3394                    tokio::spawn(async move {
3395                        for _ in 0..2 {
3396                            if let Ok((stream, _)) = listener.accept().await {
3397                                let mut ws = accept_async(stream).await.unwrap();
3398                                ws.close(None).await.ok();
3399                            }
3400                        }
3401                    });
3402
3403                    let c1 = WebsocketConnection::new("c1");
3404                    let c2 = WebsocketConnection::new("c2");
3405                    let common = WebsocketCommon::new(
3406                        vec![c1.clone(), c2.clone()],
3407                        WebsocketMode::Pool(2),
3408                        0,
3409                        None,
3410                        None,
3411                    );
3412                    let url = format!("ws://{addr}");
3413
3414                    common
3415                        .clone()
3416                        .init_connect(&url, false, None)
3417                        .await
3418                        .unwrap();
3419                    let st1 = c1.state.lock().await;
3420                    let st2 = c2.state.lock().await;
3421
3422                    assert!(st1.ws_write_tx.is_some());
3423                    assert!(st2.ws_write_tx.is_none());
3424                });
3425            }
3426
3427            #[test]
3428            fn writer_channel_can_send_text() {
3429                TOKIO_SHARED_RT.block_on(async {
3430                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3431                    let addr = listener.local_addr().unwrap();
3432                    let received = Arc::new(Mutex::new(None::<String>));
3433                    let received_clone = received.clone();
3434
3435                    tokio::spawn(async move {
3436                        if let Ok((stream, _)) = listener.accept().await {
3437                            let mut ws = accept_async(stream).await.unwrap();
3438                            if let Some(Ok(Message::Text(txt))) = ws.next().await {
3439                                *received_clone.lock().await = Some(txt.to_string());
3440                            }
3441                            ws.close(None).await.ok();
3442                        }
3443                    });
3444
3445                    let conn = WebsocketConnection::new("cw");
3446                    let common = WebsocketCommon::new(
3447                        vec![conn.clone()],
3448                        WebsocketMode::Single,
3449                        0,
3450                        None,
3451                        None,
3452                    );
3453                    let url = format!("ws://{addr}");
3454                    common
3455                        .clone()
3456                        .init_connect(&url, false, Some(conn.clone()))
3457                        .await
3458                        .unwrap();
3459
3460                    let tx = conn.state.lock().await.ws_write_tx.clone().unwrap();
3461                    tx.send(Message::Text("ping".into())).unwrap();
3462
3463                    sleep(Duration::from_millis(50)).await;
3464
3465                    let lock = received.lock().await;
3466                    assert_eq!(lock.as_deref(), Some("ping"));
3467                });
3468            }
3469
3470            #[test]
3471            fn responds_to_ping_with_pong() {
3472                TOKIO_SHARED_RT.block_on(async {
3473                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3474                    let addr = listener.local_addr().unwrap();
3475
3476                    let saw_pong = Arc::new(Mutex::new(false));
3477                    let saw_pong2 = saw_pong.clone();
3478
3479                    tokio::spawn(async move {
3480                        if let Ok((stream, _)) = listener.accept().await {
3481                            let mut ws = accept_async(stream).await.unwrap();
3482                            ws.send(Message::Ping(vec![1, 2, 3].into())).await.unwrap();
3483                            if let Some(Ok(Message::Pong(payload))) = ws.next().await {
3484                                if payload[..] == [1, 2, 3] {
3485                                    *saw_pong2.lock().await = true;
3486                                }
3487                            }
3488                            let _ = ws.close(None).await;
3489                        }
3490                    });
3491
3492                    let conn = WebsocketConnection::new("c-ping");
3493                    let common = WebsocketCommon::new(
3494                        vec![conn.clone()],
3495                        WebsocketMode::Single,
3496                        0,
3497                        None,
3498                        None,
3499                    );
3500                    let url = format!("ws://{addr}");
3501                    common
3502                        .clone()
3503                        .init_connect(&url, false, Some(conn))
3504                        .await
3505                        .unwrap();
3506
3507                    sleep(Duration::from_millis(50)).await;
3508
3509                    assert!(*saw_pong.lock().await, "server should have seen a Pong");
3510                });
3511            }
3512
3513            #[test]
3514            fn handshake_error_on_invalid_url() {
3515                TOKIO_SHARED_RT.block_on(async {
3516                    let conn = WebsocketConnection::new("c-invalid");
3517                    let common = WebsocketCommon::new(
3518                        vec![conn.clone()],
3519                        WebsocketMode::Single,
3520                        0,
3521                        None,
3522                        None,
3523                    );
3524                    let res = common
3525                        .clone()
3526                        .init_connect("not-a-url", false, Some(conn.clone()))
3527                        .await;
3528                    assert!(matches!(res, Err(WebsocketError::Handshake(_))));
3529                });
3530            }
3531
3532            #[test]
3533            fn skip_if_writer_exists_and_not_renewal() {
3534                TOKIO_SHARED_RT.block_on(async {
3535                    let conn = WebsocketConnection::new("c-writer");
3536                    let (tx, mut rx) = unbounded_channel::<Message>();
3537                    {
3538                        let mut st = conn.state.lock().await;
3539                        st.ws_write_tx = Some(tx.clone());
3540                    }
3541                    let common = WebsocketCommon::new(
3542                        vec![conn.clone()],
3543                        WebsocketMode::Single,
3544                        0,
3545                        None,
3546                        None,
3547                    );
3548                    let res = common
3549                        .clone()
3550                        .init_connect("ws://127.0.0.1:1", false, Some(conn.clone()))
3551                        .await;
3552
3553                    assert!(res.is_ok());
3554                    assert!(rx.try_recv().is_err());
3555                });
3556            }
3557
3558            #[test]
3559            fn short_circuit_on_already_renewing() {
3560                TOKIO_SHARED_RT.block_on(async {
3561                    let conn = WebsocketConnection::new("c-renew");
3562                    {
3563                        let mut st = conn.state.lock().await;
3564                        st.renewal_pending = true;
3565                    }
3566                    let common = WebsocketCommon::new(
3567                        vec![conn.clone()],
3568                        WebsocketMode::Single,
3569                        0,
3570                        None,
3571                        None,
3572                    );
3573                    let res = common
3574                        .clone()
3575                        .init_connect("ws://127.0.0.1:1", true, Some(conn.clone()))
3576                        .await;
3577
3578                    assert!(res.is_ok());
3579                    assert!(conn.state.lock().await.ws_write_tx.is_none());
3580                });
3581            }
3582
3583            #[test]
3584            fn is_renewal_true_sets_and_clears_flag() {
3585                TOKIO_SHARED_RT.block_on(async {
3586                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3587                    let addr = listener.local_addr().unwrap();
3588                    tokio::spawn(async move {
3589                        if let Ok((stream, _)) = listener.accept().await {
3590                            let mut ws = accept_async(stream).await.unwrap();
3591                            let _ = ws.close(None).await;
3592                        }
3593                    });
3594
3595                    let conn = WebsocketConnection::new("c-new-renew");
3596                    let common = WebsocketCommon::new(
3597                        vec![conn.clone()],
3598                        WebsocketMode::Single,
3599                        0,
3600                        None,
3601                        None,
3602                    );
3603                    let url = format!("ws://{addr}");
3604                    let res = common
3605                        .clone()
3606                        .init_connect(&url, true, Some(conn.clone()))
3607                        .await;
3608
3609                    assert!(res.is_ok());
3610                    let st = conn.state.lock().await;
3611                    assert!(st.ws_write_tx.is_some());
3612                    assert!(!st.renewal_pending);
3613                });
3614            }
3615
3616            #[test]
3617            fn default_connection_selected_when_none_passed() {
3618                TOKIO_SHARED_RT.block_on(async {
3619                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3620                    let addr = listener.local_addr().unwrap();
3621                    tokio::spawn(async move {
3622                        if let Ok((stream, _)) = listener.accept().await {
3623                            let mut ws = accept_async(stream).await.unwrap();
3624                            let _ = ws.close(None).await;
3625                        }
3626                    });
3627                    let conn = WebsocketConnection::new("c-default");
3628                    let common = WebsocketCommon::new(
3629                        vec![conn.clone()],
3630                        WebsocketMode::Single,
3631                        0,
3632                        None,
3633                        None,
3634                    );
3635                    let url = format!("ws://{addr}");
3636                    let res = common.clone().init_connect(&url, false, None).await;
3637
3638                    assert!(res.is_ok());
3639                    assert!(conn.state.lock().await.ws_write_tx.is_some());
3640                });
3641            }
3642
3643            #[test]
3644            fn schedules_reconnect_on_abnormal_close() {
3645                TOKIO_SHARED_RT.block_on(async {
3646                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3647                    let addr = listener.local_addr().unwrap();
3648                    tokio::spawn(async move {
3649                        if let Ok((stream, _)) = listener.accept().await {
3650                            let mut ws = accept_async(stream).await.unwrap();
3651                            ws.close(Some(tungstenite::protocol::CloseFrame {
3652                                code: tungstenite::protocol::frame::coding::CloseCode::Abnormal,
3653                                reason: "oops".into(),
3654                            }))
3655                            .await
3656                            .ok();
3657                        }
3658                    });
3659                    let conn = WebsocketConnection::new("c-close");
3660                    let common = WebsocketCommon::new(
3661                        vec![conn.clone()],
3662                        WebsocketMode::Single,
3663                        10,
3664                        None,
3665                        None,
3666                    );
3667                    let url = format!("ws://{addr}");
3668                    common
3669                        .clone()
3670                        .init_connect(&url, false, Some(conn.clone()))
3671                        .await
3672                        .unwrap();
3673
3674                    sleep(Duration::from_millis(50)).await;
3675
3676                    let st = conn.state.lock().await;
3677                    assert!(
3678                        st.reconnection_pending,
3679                        "expected reconnection_pending to be true after abnormal close"
3680                    );
3681                });
3682            }
3683        }
3684
3685        mod disconnect {
3686            use super::*;
3687
3688            #[test]
3689            fn returns_ok_when_no_connections_are_ready() {
3690                TOKIO_SHARED_RT.block_on(async {
3691                    let conn = WebsocketConnection::new("c1");
3692                    let common = WebsocketCommon::new(
3693                        vec![conn.clone()],
3694                        WebsocketMode::Single,
3695                        0,
3696                        None,
3697                        None,
3698                    );
3699                    let res = common.disconnect().await;
3700
3701                    assert!(res.is_ok());
3702                    assert!(!conn.state.lock().await.close_initiated);
3703                });
3704            }
3705
3706            #[test]
3707            fn closes_all_ready_connections() {
3708                TOKIO_SHARED_RT.block_on(async {
3709                    let conn1 = WebsocketConnection::new("c1");
3710                    let conn2 = WebsocketConnection::new("c2");
3711                    let (tx1, mut rx1) = unbounded_channel::<Message>();
3712                    let (tx2, mut rx2) = unbounded_channel::<Message>();
3713                    {
3714                        let mut s1 = conn1.state.lock().await;
3715                        s1.ws_write_tx = Some(tx1);
3716                    }
3717                    {
3718                        let mut s2 = conn2.state.lock().await;
3719                        s2.ws_write_tx = Some(tx2);
3720                    }
3721                    let common = WebsocketCommon::new(
3722                        vec![conn1.clone(), conn2.clone()],
3723                        WebsocketMode::Pool(2),
3724                        0,
3725                        None,
3726                        None,
3727                    );
3728                    let fut = common.disconnect();
3729
3730                    sleep(Duration::from_millis(50)).await;
3731
3732                    fut.await.unwrap();
3733
3734                    assert!(conn1.state.lock().await.close_initiated);
3735                    assert!(conn2.state.lock().await.close_initiated);
3736
3737                    match (rx1.try_recv(), rx2.try_recv()) {
3738                        (Ok(Message::Close(_)), Ok(Message::Close(_))) => {}
3739                        other => panic!("expected two Closes, got {other:?}"),
3740                    }
3741                });
3742            }
3743
3744            #[test]
3745            fn does_not_mark_close_initiated_if_no_writer() {
3746                TOKIO_SHARED_RT.block_on(async {
3747                    let conn = WebsocketConnection::new("c-new");
3748                    let common = WebsocketCommon::new(
3749                        vec![conn.clone()],
3750                        WebsocketMode::Single,
3751                        0,
3752                        None,
3753                        None,
3754                    );
3755                    common.disconnect().await.unwrap();
3756
3757                    assert!(!conn.state.lock().await.close_initiated);
3758                });
3759            }
3760
3761            #[test]
3762            fn mixed_pool_marks_all_and_closes_only_writers() {
3763                TOKIO_SHARED_RT.block_on(async {
3764                    let conn_w = WebsocketConnection::new("with");
3765                    let conn_wo = WebsocketConnection::new("without");
3766                    let (tx, mut rx) = unbounded_channel::<Message>();
3767                    {
3768                        let mut st = conn_w.state.lock().await;
3769                        st.ws_write_tx = Some(tx);
3770                    }
3771                    let common = WebsocketCommon::new(
3772                        vec![conn_w.clone(), conn_wo.clone()],
3773                        WebsocketMode::Pool(2),
3774                        0,
3775                        None,
3776                        None,
3777                    );
3778                    let fut = common.disconnect();
3779
3780                    sleep(Duration::from_millis(50)).await;
3781
3782                    fut.await.unwrap();
3783
3784                    assert!(conn_w.state.lock().await.close_initiated);
3785                    assert!(conn_wo.state.lock().await.close_initiated);
3786                    assert!(matches!(rx.try_recv(), Ok(Message::Close(_))));
3787                });
3788            }
3789
3790            #[test]
3791            fn after_disconnect_not_connected() {
3792                TOKIO_SHARED_RT.block_on(async {
3793                    let conn = WebsocketConnection::new("c1");
3794                    let (tx, mut _rx) = unbounded_channel::<Message>();
3795                    {
3796                        let mut st = conn.state.lock().await;
3797                        st.ws_write_tx = Some(tx);
3798                    }
3799                    let common = WebsocketCommon::new(
3800                        vec![conn.clone()],
3801                        WebsocketMode::Single,
3802                        0,
3803                        None,
3804                        None,
3805                    );
3806                    common.disconnect().await.unwrap();
3807                    assert!(!common.is_connected(Some(&conn)).await);
3808                });
3809            }
3810        }
3811
3812        mod ping_server {
3813            use super::*;
3814
3815            #[test]
3816            fn sends_ping_to_all_ready_connections() {
3817                TOKIO_SHARED_RT.block_on(async {
3818                    let mut conns = Vec::new();
3819                    for i in 0..3 {
3820                        let conn = WebsocketConnection::new(format!("c{i}"));
3821                        let (tx, rx) = unbounded_channel::<Message>();
3822                        {
3823                            let mut st = conn.state.lock().await;
3824                            st.ws_write_tx = Some(tx);
3825                        }
3826                        conns.push((conn, rx));
3827                    }
3828                    let common = WebsocketCommon::new(
3829                        conns.iter().map(|(c, _)| c.clone()).collect(),
3830                        WebsocketMode::Pool(3),
3831                        0,
3832                        None,
3833                        None,
3834                    );
3835                    common.ping_server().await;
3836                    for (_, mut rx) in conns {
3837                        match rx.try_recv() {
3838                            Ok(Message::Ping(payload)) if payload.is_empty() => {}
3839                            other => panic!("expected empty-payload Ping, got {other:?}"),
3840                        }
3841                    }
3842                });
3843            }
3844
3845            #[test]
3846            fn skips_not_ready_and_partial() {
3847                TOKIO_SHARED_RT.block_on(async {
3848                    let ready = WebsocketConnection::new("ready");
3849                    let not_ready = WebsocketConnection::new("not-ready");
3850                    let (tx_r, mut rx_r) = unbounded_channel::<Message>();
3851                    {
3852                        let mut st = ready.state.lock().await;
3853                        st.ws_write_tx = Some(tx_r);
3854                    }
3855                    {
3856                        let mut st = not_ready.state.lock().await;
3857                        st.ws_write_tx = None;
3858                    }
3859                    let common = WebsocketCommon::new(
3860                        vec![ready.clone(), not_ready.clone()],
3861                        WebsocketMode::Pool(2),
3862                        0,
3863                        None,
3864                        None,
3865                    );
3866                    common.ping_server().await;
3867                    match rx_r.try_recv() {
3868                        Ok(Message::Ping(payload)) if payload.is_empty() => {}
3869                        other => panic!("expected Ping on ready, got {other:?}"),
3870                    }
3871                });
3872            }
3873
3874            #[test]
3875            fn no_ping_when_flags_block() {
3876                TOKIO_SHARED_RT.block_on(async {
3877                    let conn = WebsocketConnection::new("c1");
3878                    let (tx, mut rx) = unbounded_channel::<Message>();
3879                    {
3880                        let mut st = conn.state.lock().await;
3881                        st.ws_write_tx = Some(tx);
3882                        st.reconnection_pending = true;
3883                    }
3884                    let common = WebsocketCommon::new(
3885                        vec![conn.clone()],
3886                        WebsocketMode::Single,
3887                        0,
3888                        None,
3889                        None,
3890                    );
3891                    common.ping_server().await;
3892                    assert!(rx.try_recv().is_err());
3893                });
3894            }
3895        }
3896
3897        mod send {
3898            use super::*;
3899
3900            #[test]
3901            fn round_robin_send_without_specific() {
3902                TOKIO_SHARED_RT.block_on(async {
3903                    let conn1 = WebsocketConnection::new("c1");
3904                    let conn2 = WebsocketConnection::new("c2");
3905                    let (tx1, mut rx1) = unbounded_channel::<Message>();
3906                    let (tx2, mut rx2) = unbounded_channel::<Message>();
3907                    {
3908                        let mut s1 = conn1.state.lock().await;
3909                        s1.ws_write_tx = Some(tx1);
3910                    }
3911                    {
3912                        let mut s2 = conn2.state.lock().await;
3913                        s2.ws_write_tx = Some(tx2);
3914                    }
3915                    let common = WebsocketCommon::new(
3916                        vec![conn1.clone(), conn2.clone()],
3917                        WebsocketMode::Pool(2),
3918                        0,
3919                        None,
3920                        None,
3921                    );
3922
3923                    let res1 = common
3924                        .send("a".into(), None, false, Duration::from_secs(1), None)
3925                        .await
3926                        .unwrap();
3927                    assert!(res1.is_none());
3928
3929                    let res2 = common
3930                        .send("b".into(), None, false, Duration::from_secs(1), None)
3931                        .await
3932                        .unwrap();
3933                    assert!(res2.is_none());
3934
3935                    assert_eq!(
3936                        if let Message::Text(t) = rx1.try_recv().unwrap() {
3937                            t
3938                        } else {
3939                            panic!()
3940                        },
3941                        "a"
3942                    );
3943                    assert_eq!(
3944                        if let Message::Text(t) = rx2.try_recv().unwrap() {
3945                            t
3946                        } else {
3947                            panic!()
3948                        },
3949                        "b"
3950                    );
3951                });
3952            }
3953
3954            #[test]
3955            fn round_robin_skips_not_ready() {
3956                TOKIO_SHARED_RT.block_on(async {
3957                    let conn1 = WebsocketConnection::new("c1");
3958                    let conn2 = WebsocketConnection::new("c2");
3959                    let (tx2, mut rx2) = unbounded_channel::<Message>();
3960                    {
3961                        let mut s1 = conn1.state.lock().await;
3962                        s1.ws_write_tx = None;
3963                    }
3964                    {
3965                        let mut s2 = conn2.state.lock().await;
3966                        s2.ws_write_tx = Some(tx2);
3967                    }
3968                    let common = WebsocketCommon::new(
3969                        vec![conn1.clone(), conn2.clone()],
3970                        WebsocketMode::Pool(2),
3971                        0,
3972                        None,
3973                        None,
3974                    );
3975                    let res = common
3976                        .send("bar".into(), None, false, Duration::from_secs(1), None)
3977                        .await
3978                        .unwrap();
3979                    assert!(res.is_none());
3980                    match rx2.try_recv().unwrap() {
3981                        Message::Text(t) => assert_eq!(t, "bar"),
3982                        other => panic!("unexpected {other:?}"),
3983                    }
3984                });
3985            }
3986
3987            #[test]
3988            fn sync_send_on_specific_connection() {
3989                TOKIO_SHARED_RT.block_on(async {
3990                    let conn1 = WebsocketConnection::new("c1");
3991                    let conn2 = WebsocketConnection::new("c2");
3992                    let (tx2, mut rx2) = unbounded_channel::<Message>();
3993                    {
3994                        let mut st = conn2.state.lock().await;
3995                        st.ws_write_tx = Some(tx2);
3996                    }
3997                    let common = WebsocketCommon::new(
3998                        vec![conn1.clone(), conn2.clone()],
3999                        WebsocketMode::Pool(2),
4000                        0,
4001                        None,
4002                        None,
4003                    );
4004                    let res = common
4005                        .send(
4006                            "payload".into(),
4007                            Some("id".into()),
4008                            false,
4009                            Duration::from_secs(1),
4010                            Some(conn2.clone()),
4011                        )
4012                        .await
4013                        .unwrap();
4014                    assert!(res.is_none());
4015                    match rx2.try_recv() {
4016                        Ok(Message::Text(t)) => assert_eq!(t, "payload"),
4017                        other => panic!("expected Text, got {other:?}"),
4018                    }
4019                });
4020            }
4021
4022            #[test]
4023            fn sync_send_with_id_does_not_insert_pending() {
4024                TOKIO_SHARED_RT.block_on(async {
4025                    let conn = WebsocketConnection::new("c1");
4026                    let (tx, mut rx) = unbounded_channel::<Message>();
4027                    {
4028                        let mut st = conn.state.lock().await;
4029                        st.ws_write_tx = Some(tx);
4030                    }
4031                    let common = WebsocketCommon::new(
4032                        vec![conn.clone()],
4033                        WebsocketMode::Single,
4034                        0,
4035                        None,
4036                        None,
4037                    );
4038                    let res = common
4039                        .send(
4040                            "msg".into(),
4041                            Some("id".into()),
4042                            false,
4043                            Duration::from_secs(1),
4044                            Some(conn.clone()),
4045                        )
4046                        .await
4047                        .unwrap();
4048                    assert!(res.is_none());
4049                    assert!(conn.state.lock().await.pending_requests.is_empty());
4050                    match rx.try_recv().unwrap() {
4051                        Message::Text(t) => assert_eq!(t, "msg"),
4052                        other => panic!("unexpected {other:?}"),
4053                    }
4054                });
4055            }
4056
4057            #[test]
4058            fn sync_send_error_if_not_ready() {
4059                TOKIO_SHARED_RT.block_on(async {
4060                    let conn = WebsocketConnection::new("c1");
4061                    let common = WebsocketCommon::new(
4062                        vec![conn.clone()],
4063                        WebsocketMode::Single,
4064                        0,
4065                        None,
4066                        None,
4067                    );
4068                    let err = common
4069                        .send(
4070                            "msg".into(),
4071                            Some("id".into()),
4072                            false,
4073                            Duration::from_secs(1),
4074                            Some(conn.clone()),
4075                        )
4076                        .await
4077                        .unwrap_err();
4078                    assert!(matches!(err, WebsocketError::NotConnected));
4079                });
4080            }
4081
4082            #[test]
4083            fn sync_send_error_when_no_ready() {
4084                TOKIO_SHARED_RT.block_on(async {
4085                    let conn = WebsocketConnection::new("c1");
4086                    let common = WebsocketCommon::new(
4087                        vec![conn.clone()],
4088                        WebsocketMode::Single,
4089                        0,
4090                        None,
4091                        None,
4092                    );
4093                    let err = common
4094                        .send("msg".into(), None, false, Duration::from_secs(1), None)
4095                        .await
4096                        .unwrap_err();
4097                    assert!(matches!(err, WebsocketError::NotConnected));
4098                });
4099            }
4100
4101            #[test]
4102            fn async_send_and_receive() {
4103                TOKIO_SHARED_RT.block_on(async {
4104                    let conn = WebsocketConnection::new("c1");
4105                    let (tx, mut rx) = unbounded_channel::<Message>();
4106                    {
4107                        let mut st = conn.state.lock().await;
4108                        st.ws_write_tx = Some(tx);
4109                    }
4110                    let common = WebsocketCommon::new(
4111                        vec![conn.clone()],
4112                        WebsocketMode::Single,
4113                        0,
4114                        None,
4115                        None,
4116                    );
4117                    let fut = common
4118                        .send(
4119                            "hello".into(),
4120                            Some("id".into()),
4121                            true,
4122                            Duration::from_secs(5),
4123                            Some(conn.clone()),
4124                        )
4125                        .await
4126                        .unwrap()
4127                        .unwrap();
4128                    match rx.try_recv() {
4129                        Ok(Message::Text(t)) => assert_eq!(t, "hello"),
4130                        other => panic!("expected Text, got {other:?}"),
4131                    }
4132                    {
4133                        let mut st = conn.state.lock().await;
4134                        let pr = st.pending_requests.remove("id").unwrap();
4135                        pr.completion.send(Ok(serde_json::json!("ok"))).unwrap();
4136                    }
4137                    let resp = fut.await.unwrap().unwrap();
4138                    assert_eq!(resp, serde_json::json!("ok"));
4139                });
4140            }
4141
4142            #[test]
4143            fn async_send_default_connection() {
4144                TOKIO_SHARED_RT.block_on(async {
4145                    let conn = WebsocketConnection::new("c1");
4146                    let (tx, mut rx) = unbounded_channel::<Message>();
4147                    {
4148                        let mut st = conn.state.lock().await;
4149                        st.ws_write_tx = Some(tx);
4150                    }
4151                    let common = WebsocketCommon::new(
4152                        vec![conn.clone()],
4153                        WebsocketMode::Single,
4154                        0,
4155                        None,
4156                        None,
4157                    );
4158                    let fut = common
4159                        .send(
4160                            "msg".into(),
4161                            Some("id".into()),
4162                            true,
4163                            Duration::from_secs(5),
4164                            None,
4165                        )
4166                        .await
4167                        .unwrap()
4168                        .unwrap();
4169                    match rx.try_recv() {
4170                        Ok(Message::Text(t)) => assert_eq!(t, "msg"),
4171                        _ => panic!("no text"),
4172                    }
4173                    {
4174                        let mut st = conn.state.lock().await;
4175                        let pr = st.pending_requests.remove("id").unwrap();
4176                        pr.completion.send(Ok(serde_json::json!(123))).unwrap();
4177                    }
4178                    let resp = fut.await.unwrap().unwrap();
4179                    assert_eq!(resp, serde_json::json!(123));
4180                });
4181            }
4182
4183            #[test]
4184            fn async_send_error_if_no_id() {
4185                TOKIO_SHARED_RT.block_on(async {
4186                    let conn = WebsocketConnection::new("c§");
4187                    let (tx, _rx) = unbounded_channel::<Message>();
4188                    {
4189                        let mut st = conn.state.lock().await;
4190                        st.ws_write_tx = Some(tx);
4191                    }
4192                    let common = WebsocketCommon::new(
4193                        vec![conn.clone()],
4194                        WebsocketMode::Single,
4195                        0,
4196                        None,
4197                        None,
4198                    );
4199                    let err = common
4200                        .send(
4201                            "msg".into(),
4202                            None,
4203                            true,
4204                            Duration::from_secs(1),
4205                            Some(conn.clone()),
4206                        )
4207                        .await
4208                        .unwrap_err();
4209                    assert!(matches!(err, WebsocketError::NotConnected));
4210                });
4211            }
4212
4213            #[test]
4214            fn timeout_rejects_async() {
4215                TOKIO_SHARED_RT.block_on(async {
4216                    pause();
4217                    let conn = WebsocketConnection::new("c1");
4218                    let (tx, _rx) = unbounded_channel::<Message>();
4219                    {
4220                        let mut st = conn.state.lock().await;
4221                        st.ws_write_tx = Some(tx);
4222                    }
4223                    let common = WebsocketCommon::new(
4224                        vec![conn.clone()],
4225                        WebsocketMode::Single,
4226                        0,
4227                        None,
4228                        None,
4229                    );
4230                    let fut = common
4231                        .send(
4232                            "msg".into(),
4233                            Some("id".into()),
4234                            true,
4235                            Duration::from_secs(1),
4236                            Some(conn.clone()),
4237                        )
4238                        .await
4239                        .unwrap()
4240                        .unwrap();
4241                    advance(Duration::from_secs(1)).await;
4242                    let res = fut.await.unwrap();
4243                    assert!(res.is_err(), "expected timeout error");
4244                    assert!(!conn.state.lock().await.pending_requests.contains_key("id"));
4245                });
4246            }
4247
4248            #[test]
4249            fn async_send_errors_if_no_connection_ready() {
4250                TOKIO_SHARED_RT.block_on(async {
4251                    let conn = WebsocketConnection::new("c1");
4252                    let common = WebsocketCommon::new(
4253                        vec![conn.clone()],
4254                        WebsocketMode::Single,
4255                        0,
4256                        None,
4257                        None,
4258                    );
4259                    let err = common
4260                        .send(
4261                            "msg".into(),
4262                            Some("id".into()),
4263                            true,
4264                            Duration::from_secs(1),
4265                            None,
4266                        )
4267                        .await
4268                        .unwrap_err();
4269                    assert!(matches!(err, WebsocketError::NotConnected));
4270                });
4271            }
4272        }
4273    }
4274
4275    mod websocket_api {
4276        use super::*;
4277
4278        mod initialisation {
4279            use super::*;
4280
4281            #[test]
4282            fn new_initializes_common() {
4283                TOKIO_SHARED_RT.block_on(async {
4284                    let conn = WebsocketConnection::new("id");
4285                    let pool = vec![conn.clone()];
4286
4287                    let sig_gen = SignatureGenerator::new(
4288                        Some("api_secret".to_string()),
4289                        None::<PrivateKey>,
4290                        None::<String>,
4291                    );
4292
4293                    let config = ConfigurationWebsocketApi {
4294                        api_key: Some("api_key".to_string()),
4295                        api_secret: Some("api_secret".to_string()),
4296                        private_key: None,
4297                        private_key_passphrase: None,
4298                        ws_url: Some("wss://example".to_string()),
4299                        mode: WebsocketMode::Single,
4300                        reconnect_delay: 1000,
4301                        signature_gen: sig_gen,
4302                        timeout: 500,
4303                        time_unit: None,
4304                        agent: None,
4305                        user_agent: build_user_agent("product"),
4306                    };
4307
4308                    let api = WebsocketApi::new(config, pool.clone());
4309
4310                    assert_eq!(api.common.connection_pool.len(), 1);
4311                    assert_eq!(api.common.mode, WebsocketMode::Single);
4312
4313                    let flag = *api.is_connecting.lock().await;
4314                    assert!(!flag);
4315                });
4316            }
4317        }
4318
4319        mod connect {
4320            use super::*;
4321
4322            #[test]
4323            fn connect_when_not_connected_establishes() {
4324                TOKIO_SHARED_RT.block_on(async {
4325                    let conn = WebsocketConnection::new("id");
4326                    {
4327                        let mut st = conn.state.lock().await;
4328                        st.ws_write_tx = None;
4329                    }
4330                    let sig = SignatureGenerator::new(
4331                        Some("api_secret".into()),
4332                        None::<PrivateKey>,
4333                        None::<String>,
4334                    );
4335                    let cfg = ConfigurationWebsocketApi {
4336                        api_key: Some("api_key".into()),
4337                        api_secret: Some("api_secret".to_string()),
4338                        private_key: None,
4339                        private_key_passphrase: None,
4340                        ws_url: Some("ws://doesnotexist:1".to_string()),
4341                        mode: WebsocketMode::Single,
4342                        reconnect_delay: 0,
4343                        signature_gen: sig,
4344                        timeout: 10,
4345                        time_unit: None,
4346                        agent: None,
4347                        user_agent: build_user_agent("product"),
4348                    };
4349                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4350                    let res = api.clone().connect().await;
4351                    assert!(!matches!(res, Err(WebsocketError::Timeout)));
4352                });
4353            }
4354
4355            #[test]
4356            fn already_connected_returns_ok() {
4357                TOKIO_SHARED_RT.block_on(async {
4358                    let conn = WebsocketConnection::new("id2");
4359                    let (tx, _) = unbounded_channel();
4360                    {
4361                        let mut st = conn.state.lock().await;
4362                        st.ws_write_tx = Some(tx);
4363                    }
4364                    let sig = SignatureGenerator::new(
4365                        Some("api_secret".to_string()),
4366                        None::<PrivateKey>,
4367                        None::<String>,
4368                    );
4369                    let cfg = ConfigurationWebsocketApi {
4370                        api_key: Some("api_key".to_string()),
4371                        api_secret: Some("api_secret".to_string()),
4372                        private_key: None,
4373                        private_key_passphrase: None,
4374                        ws_url: Some("ws://example.com".to_string()),
4375                        mode: WebsocketMode::Single,
4376                        reconnect_delay: 0,
4377                        signature_gen: sig,
4378                        timeout: 10,
4379                        time_unit: None,
4380                        agent: None,
4381                        user_agent: build_user_agent("product"),
4382                    };
4383                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4384                    let res = api.connect().await;
4385                    assert!(res.is_ok());
4386                });
4387            }
4388
4389            #[test]
4390            fn not_connected_returns_error() {
4391                TOKIO_SHARED_RT.block_on(async {
4392                    let conn = WebsocketConnection::new("id1");
4393                    let sig = SignatureGenerator::new(
4394                        Some("api_secret".to_string()),
4395                        None::<PrivateKey>,
4396                        None::<String>,
4397                    );
4398                    let cfg = ConfigurationWebsocketApi {
4399                        api_key: Some("api_key".to_string()),
4400                        api_secret: Some("api_secret".to_string()),
4401                        private_key: None,
4402                        private_key_passphrase: None,
4403                        ws_url: Some("ws://127.0.0.1:9".to_string()),
4404                        mode: WebsocketMode::Single,
4405                        reconnect_delay: 0,
4406                        signature_gen: sig,
4407                        timeout: 10,
4408                        time_unit: None,
4409                        agent: None,
4410                        user_agent: build_user_agent("product"),
4411                    };
4412                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4413                    let res = api.connect().await;
4414                    assert!(res.is_err());
4415                });
4416            }
4417
4418            #[test]
4419            fn concurrent_calls_both_error_or_ok() {
4420                TOKIO_SHARED_RT.block_on(async {
4421                    let conn = WebsocketConnection::new("id3");
4422                    let sig = SignatureGenerator::new(
4423                        Some("api_secret".to_string()),
4424                        None::<PrivateKey>,
4425                        None::<String>,
4426                    );
4427                    let cfg = ConfigurationWebsocketApi {
4428                        api_key: Some("api_key".to_string()),
4429                        api_secret: Some("api_secret".to_string()),
4430                        private_key: None,
4431                        private_key_passphrase: None,
4432                        ws_url: Some("wss://invalid-domain".to_string()),
4433                        mode: WebsocketMode::Single,
4434                        reconnect_delay: 0,
4435                        signature_gen: sig,
4436                        timeout: 10,
4437                        time_unit: None,
4438                        agent: None,
4439                        user_agent: build_user_agent("product"),
4440                    };
4441                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4442                    let fut1 = tokio::spawn(api.clone().connect());
4443                    let fut2 = tokio::spawn(api.clone().connect());
4444                    let r1 = fut1.await.unwrap();
4445                    let r2 = fut2.await.unwrap();
4446
4447                    assert!(r1.is_err());
4448                    assert!(r2.is_err() || r2.is_ok());
4449                });
4450            }
4451
4452            #[test]
4453            fn pool_failure_is_propagated() {
4454                TOKIO_SHARED_RT.block_on(async {
4455                    let conn = WebsocketConnection::new("w");
4456                    let sig = SignatureGenerator::new(
4457                        Some("api_secret".to_string()),
4458                        None::<PrivateKey>,
4459                        None::<String>,
4460                    );
4461                    let cfg = ConfigurationWebsocketApi {
4462                        api_key: Some("api_key".into()),
4463                        api_secret: Some("api_secret".to_string()),
4464                        private_key: None,
4465                        private_key_passphrase: None,
4466                        ws_url: Some("ws://doesnotexist:1".to_string()),
4467                        mode: WebsocketMode::Single,
4468                        reconnect_delay: 0,
4469                        signature_gen: sig,
4470                        timeout: 10,
4471                        time_unit: None,
4472                        agent: None,
4473                        user_agent: build_user_agent("product"),
4474                    };
4475                    let api = WebsocketApi::new(cfg, vec![conn.clone()]);
4476                    let res = api.clone().connect().await;
4477                    match res {
4478                        Err(WebsocketError::Handshake(_) | WebsocketError::Timeout) => {}
4479                        _ => panic!("expected handshake or timeout error"),
4480                    }
4481                });
4482            }
4483        }
4484
4485        mod send_message {
4486            use super::*;
4487
4488            #[test]
4489            fn unsigned_message() {
4490                TOKIO_SHARED_RT.block_on(async {
4491                    let api = create_websocket_api(None);
4492                    let conn = &api.common.connection_pool[0];
4493                    let (tx, mut rx) = unbounded_channel::<Message>();
4494                    {
4495                        let mut st = conn.state.lock().await;
4496                        st.ws_write_tx = Some(tx);
4497                    }
4498
4499                    let fut = tokio::spawn({
4500                        let api = api.clone();
4501                        async move {
4502                            let mut params = BTreeMap::new();
4503                            params.insert("foo".into(), Value::String("bar".into()));
4504                            api.send_message::<Value>(
4505                                "mymethod",
4506                                params,
4507                                WebsocketMessageSendOptions {
4508                                    with_api_key: false,
4509                                    is_signed: false,
4510                                },
4511                            )
4512                            .await
4513                            .unwrap()
4514                        }
4515                    });
4516
4517                    let sent = rx.recv().await.unwrap();
4518                    let Message::Text(txt) = sent else { panic!() };
4519                    let req: Value = serde_json::from_str(&txt).unwrap();
4520                    assert_eq!(req["method"], "mymethod");
4521                    assert!(req["params"]["foo"] == "bar");
4522                    assert!(req["params"].get("apiKey").is_none());
4523                    assert!(req["params"].get("timestamp").is_none());
4524                    assert!(req["params"].get("signature").is_none());
4525
4526                    let id = req["id"].as_str().unwrap().to_string();
4527                    let mut st = conn.state.lock().await;
4528                    let pending = st.pending_requests.remove(&id).unwrap();
4529                    let reply = json!({
4530                        "id": id,
4531                        "result": { "x": 42 },
4532                        "rateLimits": [{ "limit": 7 }]
4533                    });
4534                    pending.completion.send(Ok(reply)).unwrap();
4535
4536                    let resp = fut.await.unwrap();
4537                    let rate_limits = resp.rate_limits.unwrap_or_default();
4538
4539                    assert!(rate_limits.is_empty());
4540                    assert_eq!(resp.raw, json!({"x": 42}));
4541                });
4542            }
4543
4544            #[test]
4545            fn with_api_key_only() {
4546                TOKIO_SHARED_RT.block_on(async {
4547                    let api = create_websocket_api(None);
4548                    let conn = &api.common.connection_pool[0];
4549                    let (tx, mut rx) = unbounded_channel::<Message>();
4550                    {
4551                        let mut st = conn.state.lock().await;
4552                        st.ws_write_tx = Some(tx);
4553                    }
4554
4555                    let fut = tokio::spawn({
4556                        let api = api.clone();
4557                        async move {
4558                            let params = BTreeMap::new();
4559                            api.send_message::<Value>(
4560                                "foo",
4561                                params,
4562                                WebsocketMessageSendOptions {
4563                                    with_api_key: true,
4564                                    is_signed: false,
4565                                },
4566                            )
4567                            .await
4568                            .unwrap()
4569                        }
4570                    });
4571
4572                    let Message::Text(txt) = rx.recv().await.unwrap() else {
4573                        panic!()
4574                    };
4575                    let req: Value = serde_json::from_str(&txt).unwrap();
4576                    assert_eq!(req["params"]["apiKey"], "api_key");
4577
4578                    let id = req["id"].as_str().unwrap().to_string();
4579                    let mut st = conn.state.lock().await;
4580                    let pending = st.pending_requests.remove(&id).unwrap();
4581                    pending
4582                        .completion
4583                        .send(Ok(json!({
4584                            "id": id,
4585                            "result": {},
4586                            "rateLimits": []
4587                        })))
4588                        .unwrap();
4589
4590                    let resp = fut.await.unwrap();
4591
4592                    assert_eq!(resp.raw, json!({}));
4593                    assert!(st.pending_requests.is_empty());
4594                });
4595            }
4596
4597            #[test]
4598            fn signed_message_has_timestamp_and_signature() {
4599                TOKIO_SHARED_RT.block_on(async {
4600                    let api = create_websocket_api(None);
4601                    let conn = &api.common.connection_pool[0];
4602                    let (tx, mut rx) = unbounded_channel::<Message>();
4603                    {
4604                        let mut st = conn.state.lock().await;
4605                        st.ws_write_tx = Some(tx);
4606                    }
4607
4608                    let fut = tokio::spawn({
4609                        let api = api.clone();
4610                        async move {
4611                            let mut params = BTreeMap::new();
4612                            params.insert("foo".into(), Value::String("bar".into()));
4613                            api.send_message::<Value>(
4614                                "method",
4615                                params,
4616                                WebsocketMessageSendOptions {
4617                                    with_api_key: true,
4618                                    is_signed: true,
4619                                },
4620                            )
4621                            .await
4622                            .unwrap()
4623                        }
4624                    });
4625
4626                    let Message::Text(txt) = rx.recv().await.unwrap() else {
4627                        panic!()
4628                    };
4629                    let req: Value = serde_json::from_str(&txt).unwrap();
4630                    let p = &req["params"];
4631                    assert!(p["apiKey"] == "api_key");
4632                    assert!(p["timestamp"].is_number());
4633                    assert!(p["signature"].is_string());
4634
4635                    let id = req["id"].as_str().unwrap().to_string();
4636                    let mut st = conn.state.lock().await;
4637                    let pending = st.pending_requests.remove(&id).unwrap();
4638                    pending
4639                        .completion
4640                        .send(Ok(json!({
4641                            "id": id,
4642                            "result": { "ok": true },
4643                            "rateLimits": []
4644                        })))
4645                        .unwrap();
4646
4647                    let resp = fut.await.unwrap();
4648                    assert_eq!(resp.raw, json!({ "ok": true }));
4649                });
4650            }
4651
4652            #[test]
4653            fn error_if_not_connected() {
4654                TOKIO_SHARED_RT.block_on(async {
4655                    let api = create_websocket_api(None);
4656                    let conn = &api.common.connection_pool[0];
4657                    {
4658                        let mut st = conn.state.lock().await;
4659                        st.ws_write_tx = None;
4660                    }
4661                    let params = BTreeMap::new();
4662                    let err = api
4663                        .send_message::<Value>(
4664                            "method",
4665                            params,
4666                            WebsocketMessageSendOptions {
4667                                with_api_key: false,
4668                                is_signed: false,
4669                            },
4670                        )
4671                        .await
4672                        .unwrap_err();
4673                    matches!(err, WebsocketError::NotConnected);
4674                });
4675            }
4676        }
4677
4678        mod prepare_url {
4679            use super::*;
4680
4681            #[test]
4682            fn no_time_unit() {
4683                TOKIO_SHARED_RT.block_on(async {
4684                    let api = create_websocket_api(None);
4685                    let url = "wss://example.com/ws".to_string();
4686                    assert_eq!(api.prepare_url(&url), url);
4687                });
4688            }
4689
4690            #[test]
4691            fn appends_time_unit() {
4692                TOKIO_SHARED_RT.block_on(async {
4693                    let api = create_websocket_api(Some(TimeUnit::Millisecond));
4694                    let base = "wss://example.com/ws".to_string();
4695                    let got = api.prepare_url(&base);
4696                    assert_eq!(got, format!("{base}?timeUnit=millisecond"));
4697                });
4698            }
4699
4700            #[test]
4701            fn handles_existing_query() {
4702                TOKIO_SHARED_RT.block_on(async {
4703                    let api = create_websocket_api(Some(TimeUnit::Microsecond));
4704                    let base = "wss://example.com/ws?foo=bar".to_string();
4705                    let got = api.prepare_url(&base);
4706                    assert_eq!(got, format!("{base}&timeUnit=microsecond"));
4707                });
4708            }
4709        }
4710
4711        mod on_message {
4712            use super::*;
4713
4714            fn create_websocket_api_and_conn() -> (Arc<WebsocketApi>, Arc<WebsocketConnection>) {
4715                let sig_gen = SignatureGenerator::new(
4716                    Some("api_secret".to_string()),
4717                    None::<_>,
4718                    None::<String>,
4719                );
4720                let config = ConfigurationWebsocketApi {
4721                    api_key: Some("api_key".to_string()),
4722                    api_secret: Some("api_secret".to_string()),
4723                    private_key: None,
4724                    private_key_passphrase: None,
4725                    ws_url: Some("wss://example".to_string()),
4726                    mode: WebsocketMode::Single,
4727                    reconnect_delay: 0,
4728                    signature_gen: sig_gen,
4729                    timeout: 1000,
4730                    time_unit: None,
4731                    agent: None,
4732                    user_agent: build_user_agent("product"),
4733                };
4734                let conn = WebsocketConnection::new("test");
4735                let api = WebsocketApi::new(config, vec![conn.clone()]);
4736                (api, conn)
4737            }
4738
4739            #[test]
4740            fn resolves_pending_and_removes_request() {
4741                TOKIO_SHARED_RT.block_on(async {
4742                    let (api, conn) = create_websocket_api_and_conn();
4743                    let (tx, rx) = oneshot::channel();
4744                    {
4745                        let mut st = conn.state.lock().await;
4746                        st.pending_requests
4747                            .insert("id1".to_string(), PendingRequest { completion: tx });
4748                    }
4749                    let msg = json!({"id":"id1","status":200,"foo":"bar"});
4750                    api.on_message(msg.to_string(), conn.clone()).await;
4751                    let got = rx.await.unwrap().unwrap();
4752                    assert_eq!(got, msg);
4753                    let st = conn.state.lock().await;
4754                    assert!(!st.pending_requests.contains_key("id1"));
4755                });
4756            }
4757
4758            #[test]
4759            fn uses_result_when_present() {
4760                TOKIO_SHARED_RT.block_on(async {
4761                    let (api, conn) = create_websocket_api_and_conn();
4762                    let (tx, rx) = oneshot::channel();
4763                    {
4764                        let mut st = conn.state.lock().await;
4765                        st.pending_requests
4766                            .insert("id1".to_string(), PendingRequest { completion: tx });
4767                    }
4768                    let msg = json!({
4769                        "id": "id1",
4770                        "status": 200,
4771                        "response": [1,2],
4772                        "result": {"a":1}
4773                    });
4774                    api.on_message(msg.to_string(), conn.clone()).await;
4775                    let got = rx.await.unwrap().unwrap();
4776                    assert_eq!(got.get("result").unwrap(), &json!({"a":1}));
4777                });
4778            }
4779
4780            #[test]
4781            fn uses_response_when_no_result() {
4782                TOKIO_SHARED_RT.block_on(async {
4783                    let (api, conn) = create_websocket_api_and_conn();
4784                    let (tx, rx) = oneshot::channel();
4785                    {
4786                        let mut st = conn.state.lock().await;
4787                        st.pending_requests
4788                            .insert("id1".to_string(), PendingRequest { completion: tx });
4789                    }
4790                    let msg = json!({
4791                        "id": "id1",
4792                        "status": 200,
4793                        "response": ["ok"]
4794                    });
4795                    api.on_message(msg.to_string(), conn.clone()).await;
4796                    let got = rx.await.unwrap().unwrap();
4797                    assert_eq!(got.get("response").unwrap(), &json!(["ok"]));
4798                });
4799            }
4800
4801            #[test]
4802            fn errors_for_status_ge_400() {
4803                TOKIO_SHARED_RT.block_on(async {
4804                    let (api, conn) = create_websocket_api_and_conn();
4805                    let (tx, rx) = oneshot::channel();
4806                    {
4807                        let mut st = conn.state.lock().await;
4808                        st.pending_requests
4809                            .insert("bad".to_string(), PendingRequest { completion: tx });
4810                    }
4811                    let err_obj = json!({"code":123,"msg":"oops"});
4812                    let msg = json!({"id":"bad","status":500,"error":err_obj});
4813                    api.on_message(msg.to_string(), conn.clone()).await;
4814                    match rx.await.unwrap() {
4815                        Err(WebsocketError::ResponseError { code, message }) => {
4816                            assert_eq!(code, 123);
4817                            assert_eq!(message, "oops");
4818                        }
4819                        other => panic!("expected ResponseError, got {other:?}"),
4820                    }
4821                    let st = conn.state.lock().await;
4822                    assert!(!st.pending_requests.contains_key("bad"));
4823                });
4824            }
4825
4826            #[test]
4827            fn ignores_unknown_id() {
4828                TOKIO_SHARED_RT.block_on(async {
4829                    let (api, conn) = create_websocket_api_and_conn();
4830                    let msg = json!({"id":"nope","status":200});
4831                    api.on_message(msg.to_string(), conn.clone()).await;
4832                    let st = conn.state.lock().await;
4833                    assert!(st.pending_requests.is_empty());
4834                });
4835            }
4836
4837            #[test]
4838            fn parse_error_ignored() {
4839                TOKIO_SHARED_RT.block_on(async {
4840                    let (api, conn) = create_websocket_api_and_conn();
4841                    api.on_message("not json".to_string(), conn.clone()).await;
4842                    let st = conn.state.lock().await;
4843                    assert!(st.pending_requests.is_empty());
4844                });
4845            }
4846
4847            #[test]
4848            fn error_status_sends_error() {
4849                TOKIO_SHARED_RT.block_on(async {
4850                    let (api, conn) = create_websocket_api_and_conn();
4851                    let (tx, rx) = oneshot::channel();
4852                    {
4853                        let mut st = conn.state.lock().await;
4854                        st.pending_requests
4855                            .insert("err".to_string(), PendingRequest { completion: tx });
4856                    }
4857                    let msg = json!({
4858                        "id": "err",
4859                        "status": 500,
4860                        "error": { "code": 42, "msg": "Bad!" }
4861                    });
4862                    api.on_message(msg.to_string(), conn.clone()).await;
4863                    match rx.await.unwrap() {
4864                        Err(WebsocketError::ResponseError { code, message }) => {
4865                            assert_eq!(code, 42);
4866                            assert_eq!(message, "Bad!");
4867                        }
4868                        other => panic!("expected ResponseError, got {other:?}"),
4869                    }
4870                });
4871            }
4872
4873            #[test]
4874            fn unknown_id_logs_warning_and_leaves_pending() {
4875                TOKIO_SHARED_RT.block_on(async {
4876                    let (api, conn) = create_websocket_api_and_conn();
4877                    {
4878                        let mut st = conn.state.lock().await;
4879                        st.pending_requests.insert(
4880                            "keep".to_string(),
4881                            PendingRequest {
4882                                completion: oneshot::channel().0,
4883                            },
4884                        );
4885                    }
4886                    api.on_message(
4887                        json!({ "id": "foo", "status": 200, "result": 1 }).to_string(),
4888                        conn.clone(),
4889                    )
4890                    .await;
4891                    let st = conn.state.lock().await;
4892                    assert!(st.pending_requests.contains_key("keep"));
4893                });
4894            }
4895        }
4896    }
4897
4898    mod websocket_streams {
4899        use super::*;
4900
4901        mod initialisation {
4902            use super::*;
4903
4904            #[test]
4905            fn new_initializes_fields() {
4906                TOKIO_SHARED_RT.block_on(async {
4907                    let sig_gen = SignatureGenerator::new(
4908                        Some("api_secret".to_string()),
4909                        None::<PrivateKey>,
4910                        None::<String>,
4911                    );
4912                    let config = ConfigurationWebsocketApi {
4913                        api_key: Some("api_key".to_string()),
4914                        api_secret: Some("api_secret".to_string()),
4915                        private_key: None,
4916                        private_key_passphrase: None,
4917                        ws_url: Some("wss://example".to_string()),
4918                        mode: WebsocketMode::Single,
4919                        reconnect_delay: 1000,
4920                        signature_gen: sig_gen.clone(),
4921                        timeout: 500,
4922                        time_unit: None,
4923                        agent: None,
4924                        user_agent: build_user_agent("product"),
4925                    };
4926                    let conn1 = WebsocketConnection::new("c1");
4927                    let conn2 = WebsocketConnection::new("c2");
4928                    let api = WebsocketApi::new(config.clone(), vec![conn1.clone(), conn2.clone()]);
4929
4930                    assert_eq!(api.common.connection_pool.len(), 2);
4931                    assert!(Arc::ptr_eq(&api.common.connection_pool[0], &conn1));
4932                    assert!(Arc::ptr_eq(&api.common.connection_pool[1], &conn2));
4933                    assert_eq!(api.configuration.ws_url, Some("wss://example".to_string()));
4934                    let flag = api.is_connecting.lock().await;
4935                    assert!(!*flag);
4936                });
4937            }
4938        }
4939
4940        mod connect {
4941            use super::*;
4942
4943            #[test]
4944            fn establishes_successfully() {
4945                TOKIO_SHARED_RT.block_on(async {
4946                    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4947                    let port = listener.local_addr().unwrap().port();
4948
4949                    tokio::spawn(async move {
4950                        for _ in 0..2 {
4951                            if let Ok((stream, _)) = listener.accept().await {
4952                                let mut ws = accept_async(stream).await.unwrap();
4953                                ws.close(None).await.ok();
4954                            }
4955                        }
4956                    });
4957
4958                    let create_websocket_streams = |ws_url: &str| {
4959                        let c1 = WebsocketConnection::new("c1");
4960                        let c2 = WebsocketConnection::new("c2");
4961                        let config = ConfigurationWebsocketStreams {
4962                            ws_url: Some(ws_url.to_string()),
4963                            mode: WebsocketMode::Pool(2),
4964                            reconnect_delay: 500,
4965                            time_unit: None,
4966                            agent: None,
4967                            user_agent: build_user_agent("product"),
4968                        };
4969                        WebsocketStreams::new(config, vec![c1, c2])
4970                    };
4971
4972                    let url = format!("ws://127.0.0.1:{port}");
4973                    let ws = create_websocket_streams(&url);
4974
4975                    let res = ws.connect(vec!["stream1".into()]).await;
4976                    assert!(res.is_ok());
4977                });
4978            }
4979
4980            #[test]
4981            fn refused_returns_error() {
4982                TOKIO_SHARED_RT.block_on(async {
4983                    let ws = create_websocket_streams(Some("ws://127.0.0.1:9"), None);
4984                    let res = ws.connect(vec!["stream1".into()]).await;
4985                    assert!(res.is_err());
4986                });
4987            }
4988
4989            #[test]
4990            fn invalid_url_returns_error() {
4991                TOKIO_SHARED_RT.block_on(async {
4992                    let ws = create_websocket_streams(Some("not-a-url"), None);
4993                    let res = ws.connect(vec!["s".into()]).await;
4994                    assert!(res.is_err());
4995                });
4996            }
4997        }
4998
4999        mod disconnect {
5000            use super::*;
5001
5002            #[test]
5003            fn disconnect_clears_state_and_streams() {
5004                TOKIO_SHARED_RT.block_on(async {
5005                    let ws = create_websocket_streams(None, None);
5006                    let conn = &ws.common.connection_pool[0];
5007                    {
5008                        let mut state = conn.state.lock().await;
5009                        state.stream_callbacks.insert("s1".to_string(), Vec::new());
5010                        state.pending_subscriptions.push_back("s2".to_string());
5011                    }
5012                    {
5013                        let mut map = ws.connection_streams.lock().await;
5014                        map.insert("s3".to_string(), Arc::clone(conn));
5015                    }
5016
5017                    let res = ws.disconnect().await;
5018                    assert!(res.is_ok());
5019
5020                    let state = conn.state.lock().await;
5021                    assert!(state.stream_callbacks.is_empty());
5022                    assert!(state.pending_subscriptions.is_empty());
5023
5024                    let map = ws.connection_streams.lock().await;
5025                    assert!(map.is_empty());
5026                });
5027            }
5028        }
5029
5030        mod subscribe {
5031            use super::*;
5032
5033            #[test]
5034            fn empty_list_does_nothing() {
5035                TOKIO_SHARED_RT.block_on(async {
5036                    let ws = create_websocket_streams(None, None);
5037                    ws.clone().subscribe(Vec::new(), None).await;
5038                    let map = ws.connection_streams.lock().await;
5039                    assert!(map.is_empty());
5040                });
5041            }
5042
5043            #[test]
5044            fn queue_when_not_ready() {
5045                TOKIO_SHARED_RT.block_on(async {
5046                    let ws = create_websocket_streams(None, None);
5047                    let conn = ws.common.connection_pool[0].clone();
5048                    ws.clone().subscribe(vec!["s1".into()], None).await;
5049                    let state = conn.state.lock().await;
5050                    let pending: Vec<String> =
5051                        state.pending_subscriptions.iter().cloned().collect();
5052                    assert_eq!(pending, vec!["s1".to_string()]);
5053                });
5054            }
5055
5056            #[test]
5057            fn only_one_subscription_per_stream() {
5058                TOKIO_SHARED_RT.block_on(async {
5059                    let ws = create_websocket_streams(None, None);
5060                    let conn = ws.common.connection_pool[0].clone();
5061                    ws.clone().subscribe(vec!["s1".into()], None).await;
5062                    ws.clone().subscribe(vec!["s1".into()], None).await;
5063                    let state = conn.state.lock().await;
5064                    let pending: Vec<String> =
5065                        state.pending_subscriptions.iter().cloned().collect();
5066                    assert_eq!(pending, vec!["s1".to_string()]);
5067                });
5068            }
5069
5070            #[test]
5071            fn multiple_streams_assigned() {
5072                TOKIO_SHARED_RT.block_on(async {
5073                    let ws = create_websocket_streams(None, None);
5074                    ws.clone()
5075                        .subscribe(vec!["s1".into(), "s2".into()], None)
5076                        .await;
5077                    let map = ws.connection_streams.lock().await;
5078                    assert!(map.contains_key("s1"));
5079                    assert!(map.contains_key("s2"));
5080                });
5081            }
5082
5083            #[test]
5084            fn existing_stream_not_reassigned() {
5085                TOKIO_SHARED_RT.block_on(async {
5086                    let ws = create_websocket_streams(None, None);
5087                    ws.clone().subscribe(vec!["s1".into()], None).await;
5088                    let first_id = {
5089                        let map = ws.connection_streams.lock().await;
5090                        map.get("s1").unwrap().id.clone()
5091                    };
5092                    ws.clone()
5093                        .subscribe(vec!["s1".into(), "s2".into()], None)
5094                        .await;
5095                    let map = ws.connection_streams.lock().await;
5096                    let second_id = map.get("s1").unwrap().id.clone();
5097                    assert_eq!(first_id, second_id);
5098                    assert!(map.contains_key("s2"));
5099                });
5100            }
5101        }
5102
5103        mod unsubscribe {
5104            use super::*;
5105
5106            #[test]
5107            fn removes_stream_with_no_callbacks() {
5108                TOKIO_SHARED_RT.block_on(async {
5109                    let ws = create_websocket_streams(None, None);
5110                    let conn = ws.common.connection_pool[0].clone();
5111
5112                    {
5113                        let (tx, _rx) = unbounded_channel::<Message>();
5114                        let mut st = conn.state.lock().await;
5115                        st.ws_write_tx = Some(tx);
5116                    }
5117
5118                    {
5119                        let mut map = ws.connection_streams.lock().await;
5120                        map.insert("s1".to_string(), conn.clone());
5121                    }
5122                    {
5123                        let mut st = conn.state.lock().await;
5124                        st.stream_callbacks.insert("s1".to_string(), Vec::new());
5125                    }
5126
5127                    ws.unsubscribe(vec!["s1".to_string()], None).await;
5128
5129                    assert!(!ws.connection_streams.lock().await.contains_key("s1"));
5130                    assert!(!conn.state.lock().await.stream_callbacks.contains_key("s1"));
5131                });
5132            }
5133
5134            #[test]
5135            fn preserves_stream_with_callbacks() {
5136                TOKIO_SHARED_RT.block_on(async {
5137                    let ws = create_websocket_streams(None, None);
5138                    let conn = ws.common.connection_pool[1].clone();
5139
5140                    {
5141                        let mut map = ws.connection_streams.lock().await;
5142                        map.insert("s2".to_string(), conn.clone());
5143                    }
5144                    {
5145                        let mut state = conn.state.lock().await;
5146                        state
5147                            .stream_callbacks
5148                            .insert("s2".to_string(), vec![Arc::new(|_: &Value| {})]);
5149                    }
5150
5151                    ws.unsubscribe(vec!["s2".to_string()], None).await;
5152
5153                    assert!(ws.connection_streams.lock().await.contains_key("s2"));
5154                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s2"));
5155                });
5156            }
5157
5158            #[test]
5159            fn does_not_send_if_callbacks_exist() {
5160                TOKIO_SHARED_RT.block_on(async {
5161                    let ws = create_websocket_streams(None, None);
5162                    let conn = ws.common.connection_pool[0].clone();
5163                    {
5164                        let mut map = ws.connection_streams.lock().await;
5165                        map.insert("s1".to_string(), conn.clone());
5166                    }
5167                    {
5168                        let mut state = conn.state.lock().await;
5169                        state.stream_callbacks.insert(
5170                            "s1".to_string(),
5171                            vec![Arc::new(|_: &Value| {}), Arc::new(|_: &Value| {})],
5172                        );
5173                    }
5174                    ws.unsubscribe(vec!["s1".into()], None).await;
5175                    assert!(ws.connection_streams.lock().await.contains_key("s1"));
5176                    assert!(conn.state.lock().await.stream_callbacks.contains_key("s1"));
5177                });
5178            }
5179
5180            #[test]
5181            fn warns_if_not_associated() {
5182                TOKIO_SHARED_RT.block_on(async {
5183                    let ws = create_websocket_streams(None, None);
5184                    ws.unsubscribe(vec!["nope".into()], None).await;
5185                });
5186            }
5187
5188            #[test]
5189            fn empty_list_does_nothing() {
5190                TOKIO_SHARED_RT.block_on(async {
5191                    let ws = create_websocket_streams(None, None);
5192                    let before = ws.connection_streams.lock().await.len();
5193                    ws.unsubscribe(Vec::<String>::new(), None).await;
5194                    let after = ws.connection_streams.lock().await.len();
5195                    assert_eq!(before, after);
5196                });
5197            }
5198
5199            #[test]
5200            fn invalid_custom_id_falls_back() {
5201                TOKIO_SHARED_RT.block_on(async {
5202                    let ws = create_websocket_streams(None, None);
5203                    let conn = ws.common.connection_pool[0].clone();
5204                    {
5205                        let mut map = ws.connection_streams.lock().await;
5206                        map.insert("foo".to_string(), conn.clone());
5207                    }
5208                    {
5209                        let mut state = conn.state.lock().await;
5210                        let (tx, _rx) = unbounded_channel();
5211                        state.ws_write_tx = Some(tx);
5212                        state.stream_callbacks.insert("foo".to_string(), Vec::new());
5213                    }
5214                    ws.unsubscribe(vec!["foo".into()], Some("bad-id".into()))
5215                        .await;
5216                    assert!(!ws.connection_streams.lock().await.contains_key("foo"));
5217                });
5218            }
5219
5220            #[test]
5221            fn removes_even_without_write_channel() {
5222                TOKIO_SHARED_RT.block_on(async {
5223                    let ws = create_websocket_streams(None, None);
5224                    let conn = ws.common.connection_pool[0].clone();
5225                    {
5226                        let mut map = ws.connection_streams.lock().await;
5227                        map.insert("x".to_string(), conn.clone());
5228                    }
5229                    {
5230                        let mut state = conn.state.lock().await;
5231                        let (tx, _rx) = unbounded_channel();
5232                        state.ws_write_tx = Some(tx);
5233                        state.stream_callbacks.insert("x".to_string(), Vec::new());
5234                    }
5235                    ws.unsubscribe(vec!["x".into()], None).await;
5236                    assert!(!ws.connection_streams.lock().await.contains_key("x"));
5237                });
5238            }
5239        }
5240
5241        mod is_subscribed {
5242            use super::*;
5243
5244            #[test]
5245            fn returns_false_when_not_subscribed() {
5246                TOKIO_SHARED_RT.block_on(async {
5247                    let ws = create_websocket_streams(None, None);
5248                    assert!(!ws.is_subscribed("unknown").await);
5249                });
5250            }
5251
5252            #[test]
5253            fn returns_true_when_subscribed() {
5254                TOKIO_SHARED_RT.block_on(async {
5255                    let ws = create_websocket_streams(None, None);
5256                    let conn = ws.common.connection_pool[0].clone();
5257                    {
5258                        let mut map = ws.connection_streams.lock().await;
5259                        map.insert("stream1".to_string(), conn);
5260                    }
5261                    assert!(ws.is_subscribed("stream1").await);
5262                });
5263            }
5264        }
5265
5266        mod prepare_url {
5267            use super::*;
5268
5269            #[test]
5270            fn without_time_unit_returns_base_url() {
5271                TOKIO_SHARED_RT.block_on(async {
5272                    let conns = vec![
5273                        WebsocketConnection::new("c1"),
5274                        WebsocketConnection::new("c2"),
5275                    ];
5276                    let config = ConfigurationWebsocketStreams {
5277                        ws_url: Some("wss://example".to_string()),
5278                        mode: WebsocketMode::Single,
5279                        reconnect_delay: 100,
5280                        time_unit: None,
5281                        agent: None,
5282                        user_agent: build_user_agent("product"),
5283                    };
5284                    let ws = WebsocketStreams::new(config, conns);
5285                    let url = ws.prepare_url(&["s1".into(), "s2".into()]);
5286                    assert_eq!(url, "wss://example/stream?streams=s1/s2");
5287                });
5288            }
5289
5290            #[test]
5291            fn with_time_unit_appends_parameter() {
5292                TOKIO_SHARED_RT.block_on(async {
5293                    let conns = vec![WebsocketConnection::new("c1")];
5294                    let config = ConfigurationWebsocketStreams {
5295                        ws_url: Some("wss://example".to_string()),
5296                        mode: WebsocketMode::Single,
5297                        reconnect_delay: 100,
5298                        time_unit: Some(TimeUnit::Millisecond),
5299                        agent: None,
5300                        user_agent: build_user_agent("product"),
5301                    };
5302                    let ws = WebsocketStreams::new(config, conns);
5303                    let url = ws.prepare_url(&["a".into()]);
5304                    assert_eq!(url, "wss://example/stream?streams=a&timeUnit=millisecond");
5305                });
5306            }
5307
5308            #[test]
5309            fn multiple_streams_and_time_unit() {
5310                TOKIO_SHARED_RT.block_on(async {
5311                    let conns = vec![WebsocketConnection::new("c1")];
5312                    let config = ConfigurationWebsocketStreams {
5313                        ws_url: Some("wss://example".to_string()),
5314                        mode: WebsocketMode::Single,
5315                        reconnect_delay: 100,
5316                        time_unit: Some(TimeUnit::Microsecond),
5317                        agent: None,
5318                        user_agent: build_user_agent("product"),
5319                    };
5320                    let ws = WebsocketStreams::new(config, conns);
5321                    let url = ws.prepare_url(&["x".into(), "y".into(), "z".into()]);
5322                    assert_eq!(
5323                        url,
5324                        "wss://example/stream?streams=x/y/z&timeUnit=microsecond"
5325                    );
5326                });
5327            }
5328        }
5329
5330        mod handle_stream_assignment {
5331            use super::*;
5332
5333            #[test]
5334            fn assigns_new_streams_to_connections() {
5335                TOKIO_SHARED_RT.block_on(async {
5336                    let ws = create_websocket_streams(None, None);
5337                    let groups = ws
5338                        .clone()
5339                        .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5340                        .await;
5341                    let mut seen_streams = HashSet::new();
5342                    for (_conn, streams) in &groups {
5343                        for s in streams {
5344                            seen_streams.insert(s);
5345                        }
5346                    }
5347                    assert_eq!(
5348                        seen_streams,
5349                        ["s1".to_string(), "s2".to_string()].iter().collect()
5350                    );
5351                    assert_eq!(groups.len(), 1);
5352                });
5353            }
5354
5355            #[test]
5356            fn reuses_existing_connection_for_duplicate_stream() {
5357                TOKIO_SHARED_RT.block_on(async {
5358                    let ws = create_websocket_streams(None, None);
5359                    let _ = ws.clone().handle_stream_assignment(vec!["s1".into()]).await;
5360                    let groups = ws
5361                        .clone()
5362                        .handle_stream_assignment(vec!["s1".into(), "s3".into()])
5363                        .await;
5364                    let mut all_streams = Vec::new();
5365                    for (_conn, streams) in groups {
5366                        all_streams.extend(streams);
5367                    }
5368                    all_streams.sort();
5369                    assert_eq!(all_streams, vec!["s1".to_string(), "s3".to_string()]);
5370                });
5371            }
5372
5373            #[test]
5374            fn empty_stream_list_returns_empty() {
5375                TOKIO_SHARED_RT.block_on(async {
5376                    let ws = create_websocket_streams(None, None);
5377                    let groups = ws.clone().handle_stream_assignment(vec![]).await;
5378                    assert!(groups.is_empty());
5379                });
5380            }
5381
5382            #[test]
5383            fn closed_or_reconnecting_forces_reassignment_of_stream() {
5384                TOKIO_SHARED_RT.block_on(async {
5385                    let ws = create_websocket_streams(None, None);
5386                    let mut groups = ws.clone().handle_stream_assignment(vec!["s1".into()]).await;
5387                    let (conn, _) = groups.pop().unwrap();
5388                    {
5389                        let mut st = conn.state.lock().await;
5390                        st.close_initiated = true;
5391                    }
5392                    let groups2 = ws.clone().handle_stream_assignment(vec!["s2".into()]).await;
5393                    assert_eq!(groups2.len(), 1);
5394                    let (_new_conn, streams) = &groups2[0];
5395                    assert_eq!(streams, &vec!["s2".to_string()]);
5396                });
5397            }
5398
5399            #[test]
5400            fn no_available_connections_falls_back_to_one() {
5401                TOKIO_SHARED_RT.block_on(async {
5402                    let ws = create_websocket_streams(None, Some(vec![]));
5403                    let assigned = ws.handle_stream_assignment(vec!["foo".into()]).await;
5404                    assert_eq!(assigned.len(), 1);
5405                    let (_conn, streams) = &assigned[0];
5406                    assert_eq!(streams.as_slice(), &["foo".to_string()]);
5407                });
5408            }
5409
5410            #[test]
5411            fn single_connection_groups_multiple_streams() {
5412                TOKIO_SHARED_RT.block_on(async {
5413                    let conn = WebsocketConnection::new("c1");
5414                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5415                    let assigned = ws
5416                        .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5417                        .await;
5418                    assert_eq!(assigned.len(), 1);
5419                    let (assigned_conn, streams) = &assigned[0];
5420                    assert!(Arc::ptr_eq(assigned_conn, &conn));
5421                    assert_eq!(streams.len(), 2);
5422                    assert!(streams.contains(&"s1".to_string()));
5423                    assert!(streams.contains(&"s2".to_string()));
5424                });
5425            }
5426
5427            #[test]
5428            fn reuse_existing_healthy_connection() {
5429                TOKIO_SHARED_RT.block_on(async {
5430                    let conn = WebsocketConnection::new("c");
5431                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5432                    let _ = ws.handle_stream_assignment(vec!["s1".into()]).await;
5433                    let second = ws.handle_stream_assignment(vec!["s1".into()]).await;
5434                    assert_eq!(second.len(), 1);
5435                    let (assigned_conn, streams) = &second[0];
5436                    assert!(Arc::ptr_eq(assigned_conn, &conn));
5437                    assert_eq!(streams.as_slice(), &["s1".to_string()]);
5438                });
5439            }
5440
5441            #[test]
5442            fn mix_new_and_assigned_streams() {
5443                TOKIO_SHARED_RT.block_on(async {
5444                    let conn = WebsocketConnection::new("c");
5445                    let ws = create_websocket_streams(None, Some(vec![conn.clone()]));
5446                    let _ = ws
5447                        .handle_stream_assignment(vec!["s1".into(), "s2".into()])
5448                        .await;
5449                    let mixed = ws
5450                        .handle_stream_assignment(vec!["s2".into(), "s3".into()])
5451                        .await;
5452                    assert_eq!(mixed.len(), 1);
5453                    let (assigned_conn, streams) = &mixed[0];
5454                    assert!(Arc::ptr_eq(assigned_conn, &conn));
5455                    let mut got = streams.clone();
5456                    got.sort();
5457                    assert_eq!(got, vec!["s2".to_string(), "s3".to_string()]);
5458                });
5459            }
5460        }
5461
5462        mod send_subscription_payload {
5463            use super::*;
5464
5465            #[test]
5466            fn subscribe_payload_with_custom_id_fallbacks_if_invalid() {
5467                TOKIO_SHARED_RT.block_on(async {
5468                    let ws: Arc<WebsocketStreams> =
5469                        create_websocket_streams(Some("ws://example.com"), None);
5470                    let conn = &ws.common.connection_pool[0];
5471                    let (tx, mut rx) = unbounded_channel();
5472                    {
5473                        let mut st = conn.state.lock().await;
5474                        st.ws_write_tx = Some(tx);
5475                    }
5476                    ws.send_subscription_payload(
5477                        conn,
5478                        &vec!["s1".to_string()],
5479                        Some("badid".to_string()),
5480                    );
5481                    let msg = rx.recv().await.expect("no message sent");
5482                    if let Message::Text(txt) = msg {
5483                        let v: serde_json::Value = serde_json::from_str(&txt).unwrap();
5484                        assert_eq!(v["method"], "SUBSCRIBE");
5485                        let id = v["id"].as_str().unwrap();
5486                        assert_ne!(id, "badid");
5487                        assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id));
5488                    } else {
5489                        panic!("unexpected message: {msg:?}");
5490                    }
5491                });
5492            }
5493
5494            #[test]
5495            fn subscribe_payload_with_and_without_custom_id() {
5496                TOKIO_SHARED_RT.block_on(async {
5497                    let ws: Arc<WebsocketStreams> =
5498                        create_websocket_streams(Some("ws://unused"), None);
5499                    let conn = &ws.common.connection_pool[0];
5500                    let (tx, mut rx) = unbounded_channel();
5501                    {
5502                        let mut st = conn.state.lock().await;
5503                        st.ws_write_tx = Some(tx);
5504                    }
5505                    ws.send_subscription_payload(
5506                        conn,
5507                        &vec!["a".to_string(), "b".to_string()],
5508                        Some("deadbeefdeadbeefdeadbeefdeadbeef".to_string()),
5509                    );
5510                    let msg1 = rx.recv().await.unwrap();
5511                    ws.send_subscription_payload(conn, &vec!["x".to_string()], None);
5512                    let msg2 = rx.recv().await.unwrap();
5513
5514                    if let Message::Text(txt1) = msg1 {
5515                        let v1: serde_json::Value = serde_json::from_str(&txt1).unwrap();
5516                        assert_eq!(v1["id"], "deadbeefdeadbeefdeadbeefdeadbeef");
5517                        assert_eq!(
5518                            v1["params"].as_array().unwrap(),
5519                            &vec![serde_json::json!("a"), serde_json::json!("b")]
5520                        );
5521                    } else {
5522                        panic!()
5523                    }
5524
5525                    if let Message::Text(txt2) = msg2 {
5526                        let v2: serde_json::Value = serde_json::from_str(&txt2).unwrap();
5527                        assert_eq!(v2["method"], "SUBSCRIBE");
5528                        let params = v2["params"].as_array().unwrap();
5529                        assert_eq!(params.len(), 1);
5530                        assert_eq!(params[0], "x");
5531                        let id2 = v2["id"].as_str().unwrap();
5532                        assert!(Regex::new(r"^[0-9a-fA-F]{32}$").unwrap().is_match(id2));
5533                    } else {
5534                        panic!()
5535                    }
5536                });
5537            }
5538        }
5539
5540        mod on_open {
5541            use super::*;
5542
5543            #[test]
5544            fn sends_pending_subscriptions() {
5545                TOKIO_SHARED_RT.block_on(async {
5546                    let ws: Arc<WebsocketStreams> =
5547                        create_websocket_streams(Some("ws://example.com"), None);
5548                    let conn = &ws.common.connection_pool[0];
5549                    let (tx, mut rx) = unbounded_channel();
5550                    {
5551                        let mut st = conn.state.lock().await;
5552                        st.ws_write_tx = Some(tx);
5553                        st.pending_subscriptions.push_back("foo".to_string());
5554                        st.pending_subscriptions.push_back("bar".to_string());
5555                    }
5556                    ws.on_open("ws://example.com".to_string(), conn.clone())
5557                        .await;
5558                    let msg = rx.recv().await.expect("no subscription sent");
5559                    if let Message::Text(txt) = msg {
5560                        let v: Value = serde_json::from_str(&txt).unwrap();
5561                        assert_eq!(v["method"], "SUBSCRIBE");
5562                        let params = v["params"].as_array().unwrap();
5563                        assert_eq!(
5564                            params,
5565                            &vec![Value::String("foo".into()), Value::String("bar".into())]
5566                        );
5567                    } else {
5568                        panic!("unexpected message: {msg:?}");
5569                    }
5570                    let st_after = conn.state.lock().await;
5571                    assert!(st_after.pending_subscriptions.is_empty());
5572                });
5573            }
5574
5575            #[test]
5576            fn with_no_pending_subscriptions_sends_nothing() {
5577                TOKIO_SHARED_RT.block_on(async {
5578                    let ws: Arc<WebsocketStreams> =
5579                        create_websocket_streams(Some("ws://example.com"), None);
5580                    let conn = &ws.common.connection_pool[0];
5581                    let (tx, mut rx) = unbounded_channel();
5582                    {
5583                        let mut st = conn.state.lock().await;
5584                        st.ws_write_tx = Some(tx);
5585                    }
5586                    ws.on_open("ws://example.com".to_string(), conn.clone())
5587                        .await;
5588                    assert!(rx.try_recv().is_err(), "unexpected message sent");
5589                });
5590            }
5591
5592            #[test]
5593            fn clears_pending_without_write_channel() {
5594                TOKIO_SHARED_RT.block_on(async {
5595                    let ws: Arc<WebsocketStreams> =
5596                        create_websocket_streams(Some("ws://example.com"), None);
5597                    let conn = &ws.common.connection_pool[0];
5598                    {
5599                        let mut st = conn.state.lock().await;
5600                        st.pending_subscriptions.push_back("solo".to_string());
5601                    }
5602                    ws.on_open("ws://example.com".to_string(), conn.clone())
5603                        .await;
5604                    let st_after = conn.state.lock().await;
5605                    assert!(st_after.pending_subscriptions.is_empty());
5606                });
5607            }
5608        }
5609
5610        mod on_message {
5611            use super::*;
5612
5613            #[test]
5614            fn invokes_registered_callback() {
5615                TOKIO_SHARED_RT.block_on(async {
5616                    let ws: Arc<WebsocketStreams> =
5617                        create_websocket_streams(Some("ws://example.com"), None);
5618                    let conn = &ws.common.connection_pool[0];
5619                    let called = Arc::new(AtomicBool::new(false));
5620                    let called_clone = called.clone();
5621
5622                    {
5623                        let mut st = conn.state.lock().await;
5624                        st.stream_callbacks
5625                            .entry("stream1".to_string())
5626                            .or_default()
5627                            .push(
5628                                (Box::new(move |_: &Value| {
5629                                    called_clone.store(true, Ordering::SeqCst);
5630                                })
5631                                    as Box<dyn Fn(&Value) + Send + Sync>)
5632                                    .into(),
5633                            );
5634                    }
5635
5636                    let msg = json!({
5637                        "stream": "stream1",
5638                        "data": { "key": "value" }
5639                    })
5640                    .to_string();
5641
5642                    ws.on_message(msg, conn.clone()).await;
5643
5644                    assert!(called.load(Ordering::SeqCst));
5645                });
5646            }
5647
5648            #[test]
5649            fn invokes_all_registered_callbacks() {
5650                TOKIO_SHARED_RT.block_on(async {
5651                    let ws: Arc<WebsocketStreams> =
5652                        create_websocket_streams(Some("ws://example.com"), None);
5653                    let conn = &ws.common.connection_pool[0];
5654                    let counter = Arc::new(AtomicUsize::new(0));
5655
5656                    {
5657                        let mut st = conn.state.lock().await;
5658                        let entry = st.stream_callbacks.entry("s".into()).or_default();
5659                        let c1 = counter.clone();
5660                        entry.push(
5661                            (Box::new(move |_: &Value| {
5662                                c1.fetch_add(1, Ordering::SeqCst);
5663                            }) as Box<dyn Fn(&Value) + Send + Sync>)
5664                                .into(),
5665                        );
5666                        let c2 = counter.clone();
5667                        entry.push(
5668                            (Box::new(move |_: &Value| {
5669                                c2.fetch_add(1, Ordering::SeqCst);
5670                            }) as Box<dyn Fn(&Value) + Send + Sync>)
5671                                .into(),
5672                        );
5673                    }
5674
5675                    let msg = json!({"stream":"s","data":42}).to_string();
5676                    ws.on_message(msg, conn.clone()).await;
5677
5678                    assert_eq!(counter.load(Ordering::SeqCst), 2);
5679                });
5680            }
5681
5682            #[test]
5683            fn handles_null_data_field() {
5684                TOKIO_SHARED_RT.block_on(async {
5685                    let ws: Arc<WebsocketStreams> =
5686                        create_websocket_streams(Some("ws://example.com"), None);
5687                    let conn = &ws.common.connection_pool[0];
5688                    let called = Arc::new(AtomicUsize::new(0));
5689                    {
5690                        let mut st = conn.state.lock().await;
5691                        st.stream_callbacks.entry("n".into()).or_default().push(
5692                            (Box::new({
5693                                let c = called.clone();
5694                                move |data: &Value| {
5695                                    if data.is_null() {
5696                                        c.fetch_add(1, Ordering::SeqCst);
5697                                    }
5698                                }
5699                            }) as Box<dyn Fn(&Value) + Send + Sync>)
5700                                .into(),
5701                        );
5702                    }
5703                    let msg = json!({"stream":"n","data":null}).to_string();
5704                    ws.on_message(msg, conn.clone()).await;
5705                    assert_eq!(called.load(Ordering::SeqCst), 1);
5706                });
5707            }
5708
5709            #[test]
5710            fn with_invalid_json_does_not_panic() {
5711                TOKIO_SHARED_RT.block_on(async {
5712                    let ws: Arc<WebsocketStreams> =
5713                        create_websocket_streams(Some("ws://example.com"), None);
5714                    let conn = &ws.common.connection_pool[0];
5715                    let bad = "not a json";
5716                    ws.on_message(bad.to_string(), conn.clone()).await;
5717                });
5718            }
5719
5720            #[test]
5721            fn without_stream_field_does_nothing() {
5722                TOKIO_SHARED_RT.block_on(async {
5723                    let ws: Arc<WebsocketStreams> =
5724                        create_websocket_streams(Some("ws://example.com"), None);
5725                    let conn = &ws.common.connection_pool[0];
5726                    let msg = json!({ "data": { "foo": 1 } }).to_string();
5727                    ws.on_message(msg, conn.clone()).await;
5728                });
5729            }
5730
5731            #[test]
5732            fn with_unregistered_stream_does_not_panic() {
5733                TOKIO_SHARED_RT.block_on(async {
5734                    let ws: Arc<WebsocketStreams> =
5735                        create_websocket_streams(Some("ws://example.com"), None);
5736                    let conn = &ws.common.connection_pool[0];
5737                    let msg = json!({
5738                        "stream": "nope",
5739                        "data": { "foo": 1 }
5740                    })
5741                    .to_string();
5742                    ws.on_message(msg, conn.clone()).await;
5743                });
5744            }
5745        }
5746
5747        mod get_reconnect_url {
5748            use super::*;
5749
5750            #[test]
5751            fn single_stream_reconnect_url() {
5752                TOKIO_SHARED_RT.block_on(async {
5753                    let ws: Arc<WebsocketStreams> =
5754                        create_websocket_streams(Some("ws://example.com"), None);
5755                    let c0 = ws.common.connection_pool[0].clone();
5756                    {
5757                        let mut map = ws.connection_streams.lock().await;
5758                        map.insert("s1".to_string(), c0.clone());
5759                    }
5760                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
5761                    assert_eq!(url, "ws://example.com/stream?streams=s1");
5762                });
5763            }
5764
5765            #[test]
5766            fn multiple_streams_same_connection() {
5767                TOKIO_SHARED_RT.block_on(async {
5768                    let ws: Arc<WebsocketStreams> =
5769                        create_websocket_streams(Some("ws://example.com"), None);
5770                    let c0 = ws.common.connection_pool[0].clone();
5771                    {
5772                        let mut map = ws.connection_streams.lock().await;
5773                        map.insert("a".to_string(), c0.clone());
5774                        map.insert("b".to_string(), c0.clone());
5775                    }
5776                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
5777                    let suffix = url
5778                        .strip_prefix("ws://example.com/stream?streams=")
5779                        .unwrap();
5780                    let parts: Vec<_> = suffix.split('&').next().unwrap().split('/').collect();
5781                    let set = parts.into_iter().collect::<std::collections::HashSet<_>>();
5782                    assert_eq!(set, ["a", "b"].iter().copied().collect());
5783                });
5784            }
5785
5786            #[test]
5787            fn reconnect_url_with_time_unit() {
5788                TOKIO_SHARED_RT.block_on(async {
5789                    let mut ws: Arc<WebsocketStreams> =
5790                        create_websocket_streams(Some("ws://example.com"), None);
5791                    Arc::get_mut(&mut ws).unwrap().configuration.time_unit =
5792                        Some(TimeUnit::Microsecond);
5793                    let c0 = ws.common.connection_pool[0].clone();
5794                    {
5795                        let mut map = ws.connection_streams.lock().await;
5796                        map.insert("x".to_string(), c0.clone());
5797                    }
5798                    let url = ws.get_reconnect_url("default_url".into(), c0).await;
5799                    assert_eq!(
5800                        url,
5801                        "ws://example.com/stream?streams=x&timeUnit=microsecond"
5802                    );
5803                });
5804            }
5805        }
5806    }
5807
5808    mod websocket_stream {
5809        use super::*;
5810
5811        mod on {
5812            use super::*;
5813
5814            #[test]
5815            fn registers_callback_and_stream_callback_for_websocket_streams() {
5816                TOKIO_SHARED_RT.block_on(async {
5817                    let ws_base = create_websocket_streams(Some("example.com"), None);
5818                    let stream_name = "s1".to_string();
5819                    let conn = ws_base.common.connection_pool[0].clone();
5820                    {
5821                        let mut map = ws_base.connection_streams.lock().await;
5822                        map.insert(stream_name.clone(), conn.clone());
5823                    }
5824                    {
5825                        let mut state = conn.state.lock().await;
5826                        state
5827                            .stream_callbacks
5828                            .insert(stream_name.clone(), Vec::new());
5829                    }
5830                    let stream = Arc::new(WebsocketStream::<Value> {
5831                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5832                        stream_or_id: stream_name.clone(),
5833                        callback: Mutex::new(None),
5834                        id: None,
5835                        _phantom: PhantomData,
5836                    });
5837                    let called = Arc::new(Mutex::new(false));
5838                    let called_clone = called.clone();
5839                    stream
5840                        .on("message", move |v: Value| {
5841                            let mut lock = called_clone.blocking_lock();
5842                            *lock = v == Value::String("x".into());
5843                        })
5844                        .await;
5845                    let cb_guard = stream.callback.lock().await;
5846                    assert!(cb_guard.is_some());
5847                    let cbs = {
5848                        let state = conn.state.lock().await;
5849                        state.stream_callbacks.get(&stream_name).unwrap().clone()
5850                    };
5851                    assert_eq!(cbs.len(), 1);
5852                });
5853            }
5854
5855            #[test]
5856            fn message_twice_registers_two_wrappers_for_websocket_streams() {
5857                TOKIO_SHARED_RT.block_on(async {
5858                    let ws_base = create_websocket_streams(Some("example.com"), None);
5859                    let stream_name = "s2".to_string();
5860                    let conn = ws_base.common.connection_pool[0].clone();
5861                    {
5862                        let mut map = ws_base.connection_streams.lock().await;
5863                        map.insert(stream_name.clone(), conn.clone());
5864                    }
5865                    {
5866                        let mut state = conn.state.lock().await;
5867                        state
5868                            .stream_callbacks
5869                            .insert(stream_name.clone(), Vec::new());
5870                    }
5871                    let stream = Arc::new(WebsocketStream::<Value> {
5872                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5873                        stream_or_id: stream_name.clone(),
5874                        callback: Mutex::new(None),
5875                        id: None,
5876                        _phantom: PhantomData,
5877                    });
5878                    stream.on("message", |_| {}).await;
5879                    stream.on("message", |_| {}).await;
5880                    let state = conn.state.lock().await;
5881                    let callbacks = state.stream_callbacks.get(&stream_name).unwrap();
5882                    assert_eq!(callbacks.len(), 2);
5883                });
5884            }
5885
5886            #[test]
5887            fn ignores_non_message_event_for_websocket_streams() {
5888                TOKIO_SHARED_RT.block_on(async {
5889                    let ws_base = create_websocket_streams(Some("example.com"), None);
5890                    let stream = Arc::new(WebsocketStream::<Value> {
5891                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
5892                        stream_or_id: "s".into(),
5893                        callback: Mutex::new(None),
5894                        id: None,
5895                        _phantom: PhantomData,
5896                    });
5897                    stream.on("open", |_| {}).await;
5898                    let guard = stream.callback.lock().await;
5899                    assert!(guard.is_none());
5900                });
5901            }
5902
5903            #[test]
5904            fn registers_callback_and_stream_callback_for_websocket_api() {
5905                TOKIO_SHARED_RT.block_on(async {
5906                    let ws_base = create_websocket_api(None);
5907
5908                    {
5909                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
5910                        stream_callbacks.insert("id1".to_string(), Vec::new());
5911                    }
5912
5913                    let stream = Arc::new(WebsocketStream::<Value> {
5914                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5915                        stream_or_id: "id1".to_string(),
5916                        callback: Mutex::new(None),
5917                        id: None,
5918                        _phantom: PhantomData,
5919                    });
5920
5921                    let called = Arc::new(Mutex::new(false));
5922                    let called_clone = called.clone();
5923                    stream
5924                        .on("message", move |v: Value| {
5925                            let mut lock = called_clone.blocking_lock();
5926                            *lock = v == Value::String("x".into());
5927                        })
5928                        .await;
5929
5930                    let cb_guard = stream.callback.lock().await;
5931                    assert!(cb_guard.is_some());
5932
5933                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5934                    let callbacks = stream_callbacks.get("id1").unwrap();
5935                    assert_eq!(callbacks.len(), 1);
5936                });
5937            }
5938
5939            #[test]
5940            fn message_twice_registers_two_wrappers_for_websocket_api() {
5941                TOKIO_SHARED_RT.block_on(async {
5942                    let ws_base = create_websocket_api(None);
5943
5944                    let stream = Arc::new(WebsocketStream::<Value> {
5945                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5946                        stream_or_id: "id2".to_string(),
5947                        callback: Mutex::new(None),
5948                        id: None,
5949                        _phantom: PhantomData,
5950                    });
5951
5952                    stream.on("message", |_| {}).await;
5953                    stream.on("message", |_| {}).await;
5954
5955                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5956                    let callbacks = stream_callbacks.get("id2").unwrap();
5957                    assert_eq!(callbacks.len(), 2);
5958                });
5959            }
5960
5961            #[test]
5962            fn ignores_non_message_event_for_websocket_api() {
5963                TOKIO_SHARED_RT.block_on(async {
5964                    let ws_base = create_websocket_api(None);
5965
5966                    let stream = Arc::new(WebsocketStream::<Value> {
5967                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
5968                        stream_or_id: "id3".into(),
5969                        callback: Mutex::new(None),
5970                        id: None,
5971                        _phantom: PhantomData,
5972                    });
5973
5974                    stream.on("open", |_| {}).await;
5975
5976                    let guard = stream.callback.lock().await;
5977                    assert!(guard.is_none());
5978
5979                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
5980                    assert!(stream_callbacks.get("id3").is_none());
5981                    assert!(stream_callbacks.is_empty());
5982                });
5983            }
5984        }
5985
5986        mod on_message {
5987            use super::*;
5988
5989            #[test]
5990            fn on_message_registers_callback_for_websocket_streams() {
5991                TOKIO_SHARED_RT.block_on(async {
5992                    let ws_base = create_websocket_streams(Some("example.com"), None);
5993                    let stream_name = "s".to_string();
5994                    let conn = ws_base.common.connection_pool[0].clone();
5995                    {
5996                        let mut map = ws_base.connection_streams.lock().await;
5997                        map.insert(stream_name.clone(), conn.clone());
5998                    }
5999                    {
6000                        let mut state = conn.state.lock().await;
6001                        state
6002                            .stream_callbacks
6003                            .insert(stream_name.clone(), Vec::new());
6004                    }
6005                    let stream = Arc::new(WebsocketStream::<Value> {
6006                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
6007                        stream_or_id: stream_name.clone(),
6008                        callback: Mutex::new(None),
6009                        id: None,
6010                        _phantom: PhantomData,
6011                    });
6012                    stream.on_message(|_v| {});
6013                    let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
6014                    assert_eq!(callbacks.len(), 1);
6015                });
6016            }
6017
6018            #[test]
6019            fn on_message_twice_registers_two_callbacks_for_websocket_streams() {
6020                TOKIO_SHARED_RT.block_on(async {
6021                    let ws_base = create_websocket_streams(Some("example.com"), None);
6022                    let stream_name = "s".to_string();
6023                    let conn = ws_base.common.connection_pool[0].clone();
6024                    {
6025                        let mut map = ws_base.connection_streams.lock().await;
6026                        map.insert(stream_name.clone(), conn.clone());
6027                    }
6028                    {
6029                        let mut state = conn.state.lock().await;
6030                        state
6031                            .stream_callbacks
6032                            .insert(stream_name.clone(), Vec::new());
6033                    }
6034                    let stream = Arc::new(WebsocketStream::<Value> {
6035                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
6036                        stream_or_id: stream_name.clone(),
6037                        callback: Mutex::new(None),
6038                        id: None,
6039                        _phantom: PhantomData,
6040                    });
6041                    stream.on_message(|_v| {});
6042                    stream.on_message(|_v| {});
6043                    let callbacks = &conn.state.lock().await.stream_callbacks[&stream_name];
6044                    assert_eq!(callbacks.len(), 2);
6045                });
6046            }
6047
6048            #[test]
6049            fn on_message_registers_callback_for_websocket_api() {
6050                TOKIO_SHARED_RT.block_on(async {
6051                    let ws_base = create_websocket_api(None);
6052                    let identifier = "id1".to_string();
6053
6054                    let stream = Arc::new(WebsocketStream::<Value> {
6055                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
6056                        stream_or_id: identifier.clone(),
6057                        callback: Mutex::new(None),
6058                        id: None,
6059                        _phantom: PhantomData,
6060                    });
6061
6062                    stream.on_message(|_v: Value| {});
6063
6064                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
6065                    let callbacks = stream_callbacks.get(&identifier).unwrap();
6066                    assert_eq!(callbacks.len(), 1);
6067                });
6068            }
6069
6070            #[test]
6071            fn on_message_twice_registers_two_callbacks_for_websocket_api() {
6072                TOKIO_SHARED_RT.block_on(async {
6073                    let ws_base = create_websocket_api(None);
6074                    let identifier = "id2".to_string();
6075
6076                    let stream = Arc::new(WebsocketStream::<Value> {
6077                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
6078                        stream_or_id: identifier.clone(),
6079                        callback: Mutex::new(None),
6080                        id: None,
6081                        _phantom: PhantomData,
6082                    });
6083
6084                    stream.on_message(|_v: Value| {});
6085                    stream.on_message(|_v: Value| {});
6086
6087                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
6088                    let callbacks = stream_callbacks.get(&identifier).unwrap();
6089                    assert_eq!(callbacks.len(), 2);
6090                });
6091            }
6092        }
6093
6094        mod unsubscribe {
6095            use super::*;
6096
6097            #[test]
6098            fn without_callback_does_nothing() {
6099                TOKIO_SHARED_RT.block_on(async {
6100                    let ws_base = create_websocket_streams(Some("example.com"), None);
6101                    let stream_name = "s1".to_string();
6102                    let conn = ws_base.common.connection_pool[0].clone();
6103                    {
6104                        let mut map = ws_base.connection_streams.lock().await;
6105                        map.insert(stream_name.clone(), conn.clone());
6106                    }
6107                    let mut state = conn.state.lock().await;
6108                    state.stream_callbacks.insert(stream_name.clone(), vec![]);
6109                    drop(state);
6110                    let stream = Arc::new(WebsocketStream::<Value> {
6111                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
6112                        stream_or_id: stream_name.clone(),
6113                        callback: Mutex::new(None),
6114                        id: None,
6115                        _phantom: PhantomData,
6116                    });
6117                    stream.unsubscribe().await;
6118                    let state = conn.state.lock().await;
6119                    assert!(state.stream_callbacks.contains_key(&stream_name));
6120                });
6121            }
6122
6123            #[test]
6124            fn removes_registered_callback_and_clears_state() {
6125                TOKIO_SHARED_RT.block_on(async {
6126                    let ws_base = create_websocket_streams(Some("example.com"), None);
6127                    let stream_name = "s2".to_string();
6128                    let conn = ws_base.common.connection_pool[0].clone();
6129                    {
6130                        let mut map = ws_base.connection_streams.lock().await;
6131                        map.insert(stream_name.clone(), conn.clone());
6132                    }
6133                    {
6134                        let mut state = conn.state.lock().await;
6135                        state
6136                            .stream_callbacks
6137                            .insert(stream_name.clone(), Vec::new());
6138                    }
6139                    let stream = Arc::new(WebsocketStream::<Value> {
6140                        websocket_base: WebsocketBase::WebsocketStreams(ws_base.clone()),
6141                        stream_or_id: stream_name.clone(),
6142                        callback: Mutex::new(None),
6143                        id: None,
6144                        _phantom: PhantomData,
6145                    });
6146                    stream.on("message", |_| {}).await;
6147                    {
6148                        let guard = stream.callback.lock().await;
6149                        assert!(guard.is_some());
6150                    }
6151                    stream.unsubscribe().await;
6152                    sleep(Duration::from_millis(10)).await;
6153                    let guard = stream.callback.lock().await;
6154                    assert!(guard.is_none());
6155                    let state = conn.state.lock().await;
6156                    assert!(
6157                        state
6158                            .stream_callbacks
6159                            .get(&stream_name)
6160                            .is_none_or(std::vec::Vec::is_empty)
6161                    );
6162                });
6163            }
6164
6165            #[test]
6166            fn without_callback_does_nothing_for_websocket_api() {
6167                TOKIO_SHARED_RT.block_on(async {
6168                    let ws_base = create_websocket_api(None);
6169                    let identifier = "id1".to_string();
6170
6171                    {
6172                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
6173                        stream_callbacks.insert(identifier.clone(), Vec::new());
6174                    }
6175
6176                    let stream = Arc::new(WebsocketStream::<Value> {
6177                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
6178                        stream_or_id: identifier.clone(),
6179                        callback: Mutex::new(None),
6180                        id: None,
6181                        _phantom: PhantomData,
6182                    });
6183
6184                    stream.unsubscribe().await;
6185
6186                    let stream_callbacks = ws_base.stream_callbacks.lock().await;
6187                    assert!(stream_callbacks.contains_key(&identifier));
6188                    let callbacks = stream_callbacks.get(&identifier).unwrap();
6189                    assert!(callbacks.is_empty());
6190                });
6191            }
6192
6193            #[test]
6194            fn removes_registered_callback_and_clears_state_for_websocket_api() {
6195                TOKIO_SHARED_RT.block_on(async {
6196                    let ws_base = create_websocket_api(None);
6197                    let identifier = "id2".to_string();
6198
6199                    {
6200                        let mut stream_callbacks = ws_base.stream_callbacks.lock().await;
6201                        stream_callbacks.insert(identifier.clone(), Vec::new());
6202                    }
6203
6204                    let stream = Arc::new(WebsocketStream::<Value> {
6205                        websocket_base: WebsocketBase::WebsocketApi(ws_base.clone()),
6206                        stream_or_id: identifier.clone(),
6207                        callback: Mutex::new(None),
6208                        id: None,
6209                        _phantom: PhantomData,
6210                    });
6211
6212                    stream.on("message", |_| {}).await;
6213
6214                    {
6215                        let stream_callbacks = ws_base.stream_callbacks.lock().await;
6216                        let callbacks = stream_callbacks
6217                            .get(&identifier)
6218                            .expect("Entry for 'id2' should exist");
6219                        assert_eq!(callbacks.len(), 1);
6220                    }
6221
6222                    stream.unsubscribe().await;
6223
6224                    {
6225                        let guard = stream.callback.lock().await;
6226                        assert!(guard.is_none());
6227                    }
6228
6229                    {
6230                        let stream_callbacks = ws_base.stream_callbacks.lock().await;
6231                        let callbacks = stream_callbacks
6232                            .get(&identifier)
6233                            .expect("Entry for 'id2' should still exist");
6234                        assert!(callbacks.is_empty());
6235                    }
6236                });
6237            }
6238        }
6239    }
6240
6241    mod create_stream_handler {
6242        use super::*;
6243
6244        #[test]
6245        fn create_stream_handler_without_id_registers_stream() {
6246            TOKIO_SHARED_RT.block_on(async {
6247                let ws = create_websocket_streams(Some("ws://example.com"), None);
6248                let stream_name = "foo".to_string();
6249                let handler = create_stream_handler::<serde_json::Value>(
6250                    WebsocketBase::WebsocketStreams(ws.clone()),
6251                    stream_name.clone(),
6252                    None,
6253                )
6254                .await;
6255                assert_eq!(handler.stream_or_id, stream_name);
6256                assert!(handler.id.is_none());
6257                let map = ws.connection_streams.lock().await;
6258                assert!(map.contains_key(&stream_name));
6259            });
6260        }
6261
6262        #[test]
6263        fn create_stream_handler_with_custom_id_registers_stream_and_id() {
6264            TOKIO_SHARED_RT.block_on(async {
6265                let ws = create_websocket_streams(Some("ws://example.com"), None);
6266                let stream_name = "bar".to_string();
6267                let custom_id = Some("my-custom-id".to_string());
6268                let handler = create_stream_handler::<serde_json::Value>(
6269                    WebsocketBase::WebsocketStreams(ws.clone()),
6270                    stream_name.clone(),
6271                    custom_id.clone(),
6272                )
6273                .await;
6274                assert_eq!(handler.stream_or_id, stream_name);
6275                assert_eq!(handler.id, custom_id);
6276                let map = ws.connection_streams.lock().await;
6277                assert!(map.contains_key(&stream_name));
6278            });
6279        }
6280
6281        #[test]
6282        fn create_stream_handler_without_id_registers_api_stream() {
6283            TOKIO_SHARED_RT.block_on(async {
6284                let ws_base = create_websocket_api(None);
6285                let identifier = "foo-api".to_string();
6286
6287                let handler = create_stream_handler::<Value>(
6288                    WebsocketBase::WebsocketApi(ws_base.clone()),
6289                    identifier.clone(),
6290                    None,
6291                )
6292                .await;
6293
6294                assert_eq!(handler.stream_or_id, identifier);
6295                assert!(handler.id.is_none());
6296            });
6297        }
6298
6299        #[test]
6300        fn create_stream_handler_with_custom_id_registers_api_stream_and_id() {
6301            TOKIO_SHARED_RT.block_on(async {
6302                let ws_base = create_websocket_api(None);
6303                let identifier = "bar-api".to_string();
6304                let custom_id = Some("custom-123".to_string());
6305
6306                let handler = create_stream_handler::<Value>(
6307                    WebsocketBase::WebsocketApi(ws_base.clone()),
6308                    identifier.clone(),
6309                    custom_id.clone(),
6310                )
6311                .await;
6312
6313                assert_eq!(handler.stream_or_id, identifier);
6314                assert_eq!(handler.id, custom_id);
6315            });
6316        }
6317    }
6318}