Skip to main content

lightcone_sdk/websocket/
client.rs

1//! Main WebSocket client implementation.
2//!
3//! Provides a WebSocket client for real-time data streaming.
4
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use rand::Rng;
12
13use futures_util::stream::{SplitSink, SplitStream};
14use futures_util::{SinkExt, Stream, StreamExt};
15use pin_project_lite::pin_project;
16use tokio::net::TcpStream;
17use tokio::sync::{mpsc, RwLock};
18use tokio::time::{interval, Instant};
19use tokio_tungstenite::tungstenite::client::IntoClientRequest;
20use tokio_tungstenite::tungstenite::protocol::CloseFrame;
21use tokio_tungstenite::tungstenite::Message;
22use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
23
24use ed25519_dalek::SigningKey;
25
26use crate::websocket::auth::{authenticate, AuthCredentials};
27
28use crate::websocket::error::{WebSocketError, WsResult};
29use crate::websocket::handlers::MessageHandler;
30use crate::websocket::state::price::PriceHistoryKey;
31use crate::websocket::state::{LocalOrderbook, PriceHistory, UserState};
32use crate::websocket::subscriptions::SubscriptionManager;
33use crate::websocket::types::{SubscribeParams, WsEvent, WsRequest};
34
35type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
36type WsSink = SplitSink<WsStream, Message>;
37type WsSource = SplitStream<WsStream>;
38
/// Production Lightcone WebSocket endpoint, used by the constructors that do
/// not take an explicit URL.
pub const DEFAULT_WS_URL: &str = "wss://ws.lightcone.xyz/ws";

/// How long a WebSocket handshake may take before it fails with a timeout.
const CONNECTION_TIMEOUT: Duration = Duration::from_millis(30_000);
44
/// Tuning knobs for [`LightconeWebSocketClient`]: reconnect policy, keep-alive
/// timing, authentication, and internal channel sizes.
///
/// Start from [`WebSocketConfig::default`] and override individual fields with
/// struct-update syntax (`WebSocketConfig { auto_reconnect: false, ..Default::default() }`).
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// How many reconnect attempts to make before giving up.
    pub reconnect_attempts: u32,
    /// Starting delay of the exponential reconnect backoff, in milliseconds.
    pub base_delay_ms: u64,
    /// Ceiling of the exponential reconnect backoff, in milliseconds.
    pub max_delay_ms: u64,
    /// Seconds between client-initiated pings.
    pub ping_interval_secs: u64,
    /// Seconds without a pong after which the connection is considered dead.
    pub pong_timeout_secs: u64,
    /// Reconnect automatically when the connection drops.
    pub auto_reconnect: bool,
    /// Re-subscribe automatically after a reconnect.
    pub auto_resubscribe: bool,
    /// Auth token for private user streams; sent to the server as an
    /// `auth_token` cookie during the handshake.
    pub auth_token: Option<String>,
    /// Buffer size of the event channel that feeds the client's stream
    /// (default 1000). Should be greater than zero.
    pub event_channel_capacity: usize,
    /// Buffer size of the command channel that feeds the connection task
    /// (default 100). Should be greater than zero.
    pub command_channel_capacity: usize,
}

impl Default for WebSocketConfig {
    /// 10 reconnect attempts with a 1 s base / 30 s max backoff, a 30 s ping
    /// interval with a 60 s pong timeout, auto-reconnect and auto-resubscribe
    /// on, no auth token, and channel capacities of 1000 (events) and 100 (commands).
    fn default() -> Self {
        WebSocketConfig {
            auth_token: None,
            auto_reconnect: true,
            auto_resubscribe: true,
            base_delay_ms: 1_000,
            command_channel_capacity: 100,
            event_channel_capacity: 1_000,
            max_delay_ms: 30_000,
            ping_interval_secs: 30,
            pong_timeout_secs: 60,
            reconnect_attempts: 10,
        }
    }
}
86
/// Lifecycle phase of a [`LightconeWebSocketClient`] connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection is open.
    Disconnected,
    /// A handshake is in progress.
    Connecting,
    /// The handshake has completed.
    Connected,
    /// A dropped connection is being re-established.
    Reconnecting,
    /// A client-requested shutdown is in progress.
    Disconnecting,
}
96
/// Instructions the client hands to its background connection task.
enum ConnectionCommand {
    /// Write this text frame to the server as-is.
    Send(String),
    /// Send a normal close frame and stop the task.
    Disconnect,
    /// Send an application-level JSON ping request.
    Ping,
}
103
104pin_project! {
105    /// Main WebSocket client for Lightcone
106    ///
107    /// # Example
108    ///
109    /// ```ignore
110    /// use lightcone_sdk::websocket::*;
111    /// use futures_util::StreamExt;
112    ///
113    /// #[tokio::main]
114    /// async fn main() -> Result<(), WebSocketError> {
115    ///     let mut client = LightconeWebSocketClient::connect("ws://api.lightcone.xyz:8081/ws").await?;
116    ///
117    ///     client.subscribe_book_updates(vec!["market1:ob1".to_string()]).await?;
118    ///
119    ///     while let Some(event) = client.next().await {
120    ///         match event {
121    ///             WsEvent::BookUpdate { orderbook_id, is_snapshot } => {
122    ///                 if let Some(book) = client.get_orderbook(&orderbook_id) {
123    ///                     println!("Best bid: {:?}", book.best_bid());
124    ///                 }
125    ///             }
126    ///             _ => {}
127    ///         }
128    ///     }
129    ///     Ok(())
130    /// }
131    /// ```
132    pub struct LightconeWebSocketClient {
133        url: String,
134        config: WebSocketConfig,
135        state: ConnectionState,
136        subscriptions: Arc<RwLock<SubscriptionManager>>,
137        orderbooks: Arc<RwLock<HashMap<String, LocalOrderbook>>>,
138        user_states: Arc<RwLock<HashMap<String, UserState>>>,
139        price_histories: Arc<RwLock<HashMap<PriceHistoryKey, PriceHistory>>>,
140        subscribed_user: Arc<RwLock<Option<String>>>,
141        handler: Arc<MessageHandler>,
142        cmd_tx: Option<mpsc::Sender<ConnectionCommand>>,
143        #[pin]
144        event_rx: mpsc::Receiver<WsEvent>,
145        event_tx: mpsc::Sender<WsEvent>,
146        reconnect_attempt: u32,
147        connection_task_handle: Option<tokio::task::JoinHandle<()>>,
148        auth_credentials: Option<AuthCredentials>,
149    }
150}
151
152impl LightconeWebSocketClient {
153    /// Connect to the default Lightcone WebSocket server.
154    ///
155    /// Uses the URL `wss://ws.lightcone.xyz/ws`.
156    ///
157    /// # Example
158    ///
159    /// ```ignore
160    /// let client = LightconeWebSocketClient::connect_default().await?;
161    /// client.subscribe_book_updates(vec!["ob1".to_string()]).await?;
162    /// ```
163    pub async fn connect_default() -> WsResult<Self> {
164        Self::connect_with_config(DEFAULT_WS_URL, WebSocketConfig::default()).await
165    }
166
167    /// Connect to a WebSocket server with default configuration
168    pub async fn connect(url: &str) -> WsResult<Self> {
169        Self::connect_with_config(url, WebSocketConfig::default()).await
170    }
171
172    /// Connect to the default Lightcone WebSocket server with authentication.
173    ///
174    /// This method:
175    /// 1. Authenticates with the Lightcone API using the provided signing key
176    /// 2. Obtains an auth token
177    /// 3. Connects to the WebSocket server with the auth token
178    ///
179    /// # Arguments
180    ///
181    /// * `signing_key` - The Ed25519 signing key for authentication
182    ///
183    /// # Example
184    ///
185    /// ```ignore
186    /// use ed25519_dalek::SigningKey;
187    ///
188    /// let signing_key = SigningKey::from_bytes(&secret_key_bytes);
189    /// let client = LightconeWebSocketClient::connect_authenticated(&signing_key).await?;
190    /// client.subscribe_user(pubkey.to_string()).await?;
191    /// ```
192    pub async fn connect_authenticated(signing_key: &SigningKey) -> WsResult<Self> {
193        Self::connect_authenticated_with_config(signing_key, WebSocketConfig::default()).await
194    }
195
196    /// Connect to the default Lightcone WebSocket server with authentication and custom config.
197    pub async fn connect_authenticated_with_config(
198        signing_key: &SigningKey,
199        mut config: WebSocketConfig,
200    ) -> WsResult<Self> {
201        // Authenticate and get credentials
202        let credentials = authenticate(signing_key).await?;
203        config.auth_token = Some(credentials.auth_token.clone());
204
205        Self::connect_with_config_and_credentials(
206            DEFAULT_WS_URL,
207            config,
208            Some(credentials),
209        )
210        .await
211    }
212
213    /// Connect to a WebSocket server with a pre-obtained auth token.
214    ///
215    /// Use this if you already have an auth token from a previous authentication.
216    ///
217    /// # Arguments
218    ///
219    /// * `url` - The WebSocket URL to connect to
220    /// * `auth_token` - The auth token obtained from authentication
221    pub async fn connect_with_auth(url: &str, auth_token: String) -> WsResult<Self> {
222        let trimmed = auth_token.trim();
223        if trimmed.is_empty() {
224            return Err(WebSocketError::InvalidAuthToken(
225                "Auth token cannot be empty".to_string()
226            ));
227        }
228        let config = WebSocketConfig {
229            auth_token: Some(trimmed.to_string()),
230            ..Default::default()
231        };
232        Self::connect_with_config(url, config).await
233    }
234
235    /// Connect to a WebSocket server with custom configuration
236    pub async fn connect_with_config(url: &str, config: WebSocketConfig) -> WsResult<Self> {
237        Self::connect_with_config_and_credentials(url, config, None).await
238    }
239
240    /// Internal method to connect with config and optional credentials
241    async fn connect_with_config_and_credentials(
242        url: &str,
243        config: WebSocketConfig,
244        auth_credentials: Option<AuthCredentials>,
245    ) -> WsResult<Self> {
246        let (event_tx, event_rx) = mpsc::channel(config.event_channel_capacity);
247
248        let orderbooks = Arc::new(RwLock::new(HashMap::new()));
249        let user_states = Arc::new(RwLock::new(HashMap::new()));
250        let price_histories = Arc::new(RwLock::new(HashMap::new()));
251        let subscribed_user = Arc::new(RwLock::new(None));
252        let subscriptions = Arc::new(RwLock::new(SubscriptionManager::new()));
253
254        let handler = Arc::new(MessageHandler::new(
255            orderbooks.clone(),
256            user_states.clone(),
257            price_histories.clone(),
258            subscribed_user.clone(),
259        ));
260
261        let mut client = Self {
262            url: url.to_string(),
263            config,
264            state: ConnectionState::Disconnected,
265            subscriptions,
266            orderbooks,
267            user_states,
268            price_histories,
269            subscribed_user,
270            handler,
271            cmd_tx: None,
272            event_rx,
273            event_tx,
274            reconnect_attempt: 0,
275            connection_task_handle: None,
276            auth_credentials,
277        };
278
279        client.establish_connection().await?;
280        Ok(client)
281    }
282
283    /// Establish the WebSocket connection
284    async fn establish_connection(&mut self) -> WsResult<()> {
285        self.state = ConnectionState::Connecting;
286
287        // Build the WebSocket request, optionally with auth cookie
288        let ws_stream = if let Some(ref auth_token) = self.config.auth_token {
289            let mut request = self
290                .url
291                .as_str()
292                .into_client_request()
293                .map_err(|e| WebSocketError::InvalidUrl(e.to_string()))?;
294
295            request.headers_mut().insert(
296                "Cookie",
297                format!("auth_token={}", auth_token)
298                    .parse()
299                    .map_err(|e| WebSocketError::Protocol(format!("Invalid cookie header: {}", e)))?,
300            );
301
302            let (stream, _) = tokio::time::timeout(CONNECTION_TIMEOUT, connect_async(request))
303                .await
304                .map_err(|_| WebSocketError::Timeout)?
305                .map_err(WebSocketError::from)?;
306            stream
307        } else {
308            let (stream, _) = tokio::time::timeout(CONNECTION_TIMEOUT, connect_async(&self.url))
309                .await
310                .map_err(|_| WebSocketError::Timeout)?
311                .map_err(WebSocketError::from)?;
312            stream
313        };
314
315        self.state = ConnectionState::Connected;
316        self.reconnect_attempt = 0;
317
318        let (sink, source) = ws_stream.split();
319        let (cmd_tx, cmd_rx) = mpsc::channel(self.config.command_channel_capacity);
320        self.cmd_tx = Some(cmd_tx);
321
322        // Spawn the connection task
323        let ctx = ConnectionContext {
324            handler: self.handler.clone(),
325            event_tx: self.event_tx.clone(),
326            config: self.config.clone(),
327            subscriptions: self.subscriptions.clone(),
328            url: self.url.clone(),
329        };
330
331        let handle = tokio::spawn(connection_task(sink, source, cmd_rx, ctx));
332        self.connection_task_handle = Some(handle);
333
334        // Send connected event
335        let _ = self.event_tx.send(WsEvent::Connected).await;
336
337        Ok(())
338    }
339
340    /// Subscribe to orderbook updates
341    pub async fn subscribe_book_updates(&mut self, orderbook_ids: Vec<String>) -> WsResult<()> {
342        // Initialize state for each orderbook
343        for id in &orderbook_ids {
344            self.handler.init_orderbook(id).await;
345        }
346
347        // Track subscription
348        self.subscriptions.write().await.add_book_update(orderbook_ids.clone());
349
350        // Send subscribe request
351        let params = SubscribeParams::book_update(orderbook_ids);
352        self.send_subscribe(params).await
353    }
354
355    /// Subscribe to trade executions
356    pub async fn subscribe_trades(&mut self, orderbook_ids: Vec<String>) -> WsResult<()> {
357        self.subscriptions.write().await.add_trades(orderbook_ids.clone());
358        let params = SubscribeParams::trades(orderbook_ids);
359        self.send_subscribe(params).await
360    }
361
362    /// Subscribe to user events
363    pub async fn subscribe_user(&mut self, user: String) -> WsResult<()> {
364        self.handler.init_user_state(&user).await;
365        self.subscriptions.write().await.add_user(user.clone());
366        let params = SubscribeParams::user(user);
367        self.send_subscribe(params).await
368    }
369
370    /// Subscribe to price history
371    pub async fn subscribe_price_history(
372        &mut self,
373        orderbook_id: String,
374        resolution: String,
375        include_ohlcv: bool,
376    ) -> WsResult<()> {
377        self.handler
378            .init_price_history(&orderbook_id, &resolution, include_ohlcv)
379            .await;
380        self.subscriptions
381            .write()
382            .await
383            .add_price_history(orderbook_id.clone(), resolution.clone(), include_ohlcv);
384        let params = SubscribeParams::price_history(orderbook_id, resolution, include_ohlcv);
385        self.send_subscribe(params).await
386    }
387
388    /// Subscribe to market events
389    pub async fn subscribe_market(&mut self, market_pubkey: String) -> WsResult<()> {
390        self.subscriptions.write().await.add_market(market_pubkey.clone());
391        let params = SubscribeParams::market(market_pubkey);
392        self.send_subscribe(params).await
393    }
394
395    /// Unsubscribe from orderbook updates
396    pub async fn unsubscribe_book_updates(&mut self, orderbook_ids: Vec<String>) -> WsResult<()> {
397        self.subscriptions.write().await.remove_book_update(&orderbook_ids);
398        let params = SubscribeParams::book_update(orderbook_ids);
399        self.send_unsubscribe(params).await
400    }
401
402    /// Unsubscribe from trades
403    pub async fn unsubscribe_trades(&mut self, orderbook_ids: Vec<String>) -> WsResult<()> {
404        self.subscriptions.write().await.remove_trades(&orderbook_ids);
405        let params = SubscribeParams::trades(orderbook_ids);
406        self.send_unsubscribe(params).await
407    }
408
409    /// Unsubscribe from user events
410    pub async fn unsubscribe_user(&mut self, user: String) -> WsResult<()> {
411        self.handler.clear_subscribed_user(&user).await;
412        self.subscriptions.write().await.remove_user(&user);
413        let params = SubscribeParams::user(user);
414        self.send_unsubscribe(params).await
415    }
416
417    /// Unsubscribe from price history
418    pub async fn unsubscribe_price_history(
419        &mut self,
420        orderbook_id: String,
421        resolution: String,
422    ) -> WsResult<()> {
423        self.subscriptions
424            .write()
425            .await
426            .remove_price_history(&orderbook_id, &resolution);
427        let params = SubscribeParams::price_history(orderbook_id, resolution, false);
428        self.send_unsubscribe(params).await
429    }
430
431    /// Unsubscribe from market events
432    pub async fn unsubscribe_market(&mut self, market_pubkey: String) -> WsResult<()> {
433        self.subscriptions.write().await.remove_market(&market_pubkey);
434        let params = SubscribeParams::market(market_pubkey);
435        self.send_unsubscribe(params).await
436    }
437
438    /// Send a subscribe request
439    async fn send_subscribe(&self, params: SubscribeParams) -> WsResult<()> {
440        let request = WsRequest::subscribe(params);
441        self.send_json(&request).await
442    }
443
444    /// Send an unsubscribe request
445    async fn send_unsubscribe(&self, params: SubscribeParams) -> WsResult<()> {
446        let request = WsRequest::unsubscribe(params);
447        self.send_json(&request).await
448    }
449
450    /// Send a ping request
451    pub async fn ping(&mut self) -> WsResult<()> {
452        if let Some(tx) = &self.cmd_tx {
453            tx.send(ConnectionCommand::Ping)
454                .await
455                .map_err(|_| WebSocketError::ChannelClosed)?;
456        }
457        Ok(())
458    }
459
460    /// Send a JSON message
461    async fn send_json<T: serde::Serialize>(&self, msg: &T) -> WsResult<()> {
462        let json = serde_json::to_string(msg)?;
463        self.send_text(json).await
464    }
465
466    /// Send a text message
467    async fn send_text(&self, text: String) -> WsResult<()> {
468        if let Some(tx) = &self.cmd_tx {
469            tx.send(ConnectionCommand::Send(text))
470                .await
471                .map_err(|_| WebSocketError::ChannelClosed)?;
472            Ok(())
473        } else {
474            Err(WebSocketError::NotConnected)
475        }
476    }
477
478    /// Disconnect from the server
479    pub async fn disconnect(&mut self) -> WsResult<()> {
480        self.state = ConnectionState::Disconnecting;
481
482        // Send disconnect command to the connection task
483        if let Some(tx) = self.cmd_tx.take() {
484            let _ = tx.send(ConnectionCommand::Disconnect).await;
485        }
486
487        // Wait for the connection task to finish
488        if let Some(handle) = self.connection_task_handle.take() {
489            let _ = handle.await;
490        }
491
492        self.state = ConnectionState::Disconnected;
493        Ok(())
494    }
495
496    /// Check if the connection task is still running
497    pub fn is_task_running(&self) -> bool {
498        self.connection_task_handle
499            .as_ref()
500            .map(|h| !h.is_finished())
501            .unwrap_or(false)
502    }
503
504    /// Get the current connection state
505    pub fn connection_state(&self) -> ConnectionState {
506        self.state
507    }
508
509    /// Check if connected
510    pub fn is_connected(&self) -> bool {
511        self.state == ConnectionState::Connected
512    }
513
514    /// Check if the connection is authenticated
515    pub fn is_authenticated(&self) -> bool {
516        self.config.auth_token.is_some()
517    }
518
519    /// Get the authentication credentials if available
520    pub fn auth_credentials(&self) -> Option<&AuthCredentials> {
521        self.auth_credentials.as_ref()
522    }
523
524    /// Get the user's public key if authenticated
525    pub fn user_pubkey(&self) -> Option<&str> {
526        self.auth_credentials.as_ref().map(|c| c.user_pubkey.as_str())
527    }
528
529    /// Get a reference to the local orderbook state.
530    pub async fn get_orderbook(&self, orderbook_id: &str) -> Option<LocalOrderbook> {
531        let orderbooks = self.orderbooks.read().await;
532        orderbooks.get(orderbook_id).cloned()
533    }
534
535    /// Get a reference to the local user state.
536    pub async fn get_user_state(&self, user: &str) -> Option<UserState> {
537        let states = self.user_states.read().await;
538        states.get(user).cloned()
539    }
540
541    /// Get a reference to the price history state.
542    pub async fn get_price_history(
543        &self,
544        orderbook_id: &str,
545        resolution: &str,
546    ) -> Option<PriceHistory> {
547        let key = PriceHistoryKey::new(orderbook_id.to_string(), resolution.to_string());
548        let histories = self.price_histories.read().await;
549        histories.get(&key).cloned()
550    }
551
552    /// Get the WebSocket URL
553    pub fn url(&self) -> &str {
554        &self.url
555    }
556
557    /// Get the configuration
558    pub fn config(&self) -> &WebSocketConfig {
559        &self.config
560    }
561}
562
563impl Stream for LightconeWebSocketClient {
564    type Item = WsEvent;
565
566    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
567        let mut this = self.project();
568        this.event_rx.poll_recv(cx)
569    }
570}
571
572/// Shared context for the connection task
573struct ConnectionContext {
574    handler: Arc<MessageHandler>,
575    event_tx: mpsc::Sender<WsEvent>,
576    config: WebSocketConfig,
577    subscriptions: Arc<RwLock<SubscriptionManager>>,
578    url: String,
579}
580
581/// Connection task that handles the WebSocket connection
582async fn connection_task(
583    mut sink: WsSink,
584    mut source: WsSource,
585    mut cmd_rx: mpsc::Receiver<ConnectionCommand>,
586    ctx: ConnectionContext,
587) {
588    let ConnectionContext {
589        handler,
590        event_tx,
591        config,
592        subscriptions,
593        url,
594    } = ctx;
595    let ping_interval_duration = Duration::from_secs(config.ping_interval_secs);
596    let pong_timeout_duration = Duration::from_secs(config.pong_timeout_secs);
597    let mut ping_interval = interval(ping_interval_duration);
598    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
599
600    let mut reconnect_attempt = 0u32;
601    let mut last_pong = Instant::now();
602    let mut awaiting_pong = false;
603
604    loop {
605        tokio::select! {
606            // Handle incoming WebSocket messages
607            msg = source.next() => {
608                match msg {
609                    Some(Ok(Message::Text(text))) => {
610                        let events = handler.handle_message(&text).await;
611                        for event in events {
612                            // Handle JSON pong - update timeout tracking
613                            if matches!(event, WsEvent::Pong) {
614                                last_pong = Instant::now();
615                                awaiting_pong = false;
616                            }
617
618                            // Use try_send to avoid blocking the connection task if consumer is slow
619                            match event_tx.try_send(event) {
620                                Ok(_) => {}
621                                Err(mpsc::error::TrySendError::Full(dropped_event)) => {
622                                    tracing::warn!(
623                                        "Event channel full, dropping event: {:?}",
624                                        std::mem::discriminant(&dropped_event)
625                                    );
626                                }
627                                Err(mpsc::error::TrySendError::Closed(_)) => {
628                                    tracing::debug!("Event receiver dropped");
629                                    return;
630                                }
631                            }
632                        }
633                    }
634                    Some(Ok(Message::Ping(data))) => {
635                        if let Err(e) = sink.send(Message::Pong(data)).await {
636                            tracing::warn!("Failed to send pong: {}", e);
637                        }
638                    }
639                    Some(Ok(Message::Pong(_))) => {
640                        // Received pong response - update tracking
641                        last_pong = Instant::now();
642                        awaiting_pong = false;
643                        let _ = event_tx.send(WsEvent::Pong).await;
644                    }
645                    Some(Ok(Message::Close(frame))) => {
646                        let close_code: u16 = frame.as_ref().map(|f| f.code.into()).unwrap_or(0);
647                        let reason = frame
648                            .as_ref()
649                            .map(|f| format!("code: {}, reason: {}", f.code, f.reason))
650                            .unwrap_or_else(|| "no reason".to_string());
651
652                        tracing::info!("WebSocket closed: {}", reason);
653                        let _ = event_tx.send(WsEvent::Disconnected { reason: reason.clone() }).await;
654
655                        // Check if rate limited (close code 1008)
656                        if close_code == 1008 {
657                            let _ = event_tx.send(WsEvent::Error {
658                                error: WebSocketError::RateLimited,
659                            }).await;
660                        }
661
662                        // Try to reconnect if enabled
663                        if config.auto_reconnect && reconnect_attempt < config.reconnect_attempts {
664                            reconnect_attempt += 1;
665                            let _ = event_tx.send(WsEvent::Reconnecting {
666                                attempt: reconnect_attempt,
667                            }).await;
668
669                            // Full jitter: randomize between 0 and exponential delay to prevent thundering herd
670                            let max_delay = config.base_delay_ms * 2u64.pow(reconnect_attempt.saturating_sub(1));
671                            let jittered_delay = rand::thread_rng().gen_range(0..=max_delay);
672                            let delay = jittered_delay.min(config.max_delay_ms);
673                            tokio::time::sleep(Duration::from_millis(delay)).await;
674
675                            // Try to reconnect
676                            match reconnect(&url, &handler, &subscriptions, &config).await {
677                                Ok((new_sink, new_source)) => {
678                                    sink = new_sink;
679                                    source = new_source;
680                                    reconnect_attempt = 0;
681                                    last_pong = Instant::now();
682                                    awaiting_pong = false;
683                                    let _ = event_tx.send(WsEvent::Connected).await;
684                                }
685                                Err(e) => {
686                                    tracing::error!("Reconnect failed: {:?}", e);
687                                    let _ = event_tx.send(WsEvent::Error { error: e }).await;
688                                }
689                            }
690                        } else {
691                            return;
692                        }
693                    }
694                    Some(Ok(Message::Binary(_))) => {
695                        // Ignore binary messages
696                    }
697                    Some(Ok(Message::Frame(_))) => {
698                        // Ignore raw frames
699                    }
700                    Some(Err(e)) => {
701                        tracing::error!("WebSocket error: {}", e);
702                        let _ = event_tx.send(WsEvent::Error {
703                            error: WebSocketError::from(e),
704                        }).await;
705                    }
706                    None => {
707                        tracing::info!("WebSocket stream ended");
708                        let _ = event_tx.send(WsEvent::Disconnected {
709                            reason: "Stream ended".to_string(),
710                        }).await;
711                        return;
712                    }
713                }
714            }
715
716            // Handle commands from the client
717            cmd = cmd_rx.recv() => {
718                match cmd {
719                    Some(ConnectionCommand::Send(text)) => {
720                        if let Err(e) = sink.send(Message::Text(text.into())).await {
721                            tracing::warn!("Failed to send message: {}", e);
722                        }
723                    }
724                    Some(ConnectionCommand::Ping) => {
725                        let request = WsRequest::ping();
726                        if let Ok(json) = serde_json::to_string(&request) {
727                            if let Err(e) = sink.send(Message::Text(json.into())).await {
728                                tracing::warn!("Failed to send ping: {}", e);
729                            }
730                        }
731                    }
732                    Some(ConnectionCommand::Disconnect) => {
733                        let _ = sink.send(Message::Close(Some(CloseFrame {
734                            code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
735                            reason: "Client disconnect".into(),
736                        }))).await;
737                        return;
738                    }
739                    None => {
740                        // Command channel closed
741                        return;
742                    }
743                }
744            }
745
746            // Periodic ping with timeout check
747            _ = ping_interval.tick() => {
748                // Check if we're still waiting for a pong from the previous ping
749                if awaiting_pong && last_pong.elapsed() > pong_timeout_duration {
750                    tracing::warn!("Pong timeout: no response received within {:?}", pong_timeout_duration);
751                    let _ = event_tx.send(WsEvent::Error {
752                        error: WebSocketError::PingTimeout,
753                    }).await;
754
755                    // Treat this as a disconnect and try to reconnect
756                    let _ = event_tx.send(WsEvent::Disconnected {
757                        reason: "Ping timeout".to_string(),
758                    }).await;
759
760                    if config.auto_reconnect && reconnect_attempt < config.reconnect_attempts {
761                        reconnect_attempt += 1;
762                        let _ = event_tx.send(WsEvent::Reconnecting {
763                            attempt: reconnect_attempt,
764                        }).await;
765
766                        let max_delay = config.base_delay_ms * 2u64.pow(reconnect_attempt.saturating_sub(1));
767                        let jittered_delay = rand::thread_rng().gen_range(0..=max_delay);
768                        let delay = jittered_delay.min(config.max_delay_ms);
769                        tokio::time::sleep(Duration::from_millis(delay)).await;
770
771                        match reconnect(&url, &handler, &subscriptions, &config).await {
772                            Ok((new_sink, new_source)) => {
773                                sink = new_sink;
774                                source = new_source;
775                                reconnect_attempt = 0;
776                                last_pong = Instant::now();
777                                awaiting_pong = false;
778                                let _ = event_tx.send(WsEvent::Connected).await;
779                            }
780                            Err(e) => {
781                                tracing::error!("Reconnect failed: {:?}", e);
782                                let _ = event_tx.send(WsEvent::Error { error: e }).await;
783                            }
784                        }
785                    } else {
786                        return;
787                    }
788                } else {
789                    // Send ping
790                    let request = WsRequest::ping();
791                    if let Ok(json) = serde_json::to_string(&request) {
792                        if let Err(e) = sink.send(Message::Text(json.into())).await {
793                            tracing::warn!("Failed to send periodic ping: {}", e);
794                        } else {
795                            awaiting_pong = true;
796                        }
797                    }
798                }
799            }
800        }
801    }
802}
803
804/// Reconnect to the WebSocket server
805async fn reconnect(
806    url: &str,
807    handler: &Arc<MessageHandler>,
808    subscriptions: &Arc<RwLock<SubscriptionManager>>,
809    config: &WebSocketConfig,
810) -> WsResult<(WsSink, WsSource)> {
811    // Build the WebSocket request, optionally with auth cookie
812    let ws_stream = if let Some(ref auth_token) = config.auth_token {
813        let mut request = url
814            .into_client_request()
815            .map_err(|e| WebSocketError::InvalidUrl(e.to_string()))?;
816
817        request.headers_mut().insert(
818            "Cookie",
819            format!("auth_token={}", auth_token)
820                .parse()
821                .map_err(|e| WebSocketError::Protocol(format!("Invalid cookie header: {}", e)))?,
822        );
823
824        let (stream, _) = tokio::time::timeout(CONNECTION_TIMEOUT, connect_async(request))
825            .await
826            .map_err(|_| WebSocketError::Timeout)?
827            .map_err(WebSocketError::from)?;
828        stream
829    } else {
830        let (stream, _) = tokio::time::timeout(CONNECTION_TIMEOUT, connect_async(url))
831            .await
832            .map_err(|_| WebSocketError::Timeout)?
833            .map_err(WebSocketError::from)?;
834        stream
835    };
836
837    let (mut sink, source) = ws_stream.split();
838
839    // Clear state
840    handler.clear_all().await;
841
842    // Re-subscribe if enabled
843    if config.auto_resubscribe {
844        let subs = subscriptions.read().await.get_all_subscriptions();
845        for sub in subs {
846            let request = WsRequest::subscribe(sub.to_params());
847            if let Ok(json) = serde_json::to_string(&request) {
848                if let Err(e) = sink.send(Message::Text(json.into())).await {
849                    tracing::warn!("Failed to re-subscribe after reconnect: {}", e);
850                }
851            }
852        }
853    }
854
855    Ok((sink, source))
856}
857
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the values produced by `WebSocketConfig::default()`.
    #[test]
    fn test_config_default() {
        let cfg = WebSocketConfig::default();

        // Reconnect / backoff tuning.
        assert_eq!(cfg.reconnect_attempts, 10);
        assert_eq!(cfg.base_delay_ms, 1000);
        assert_eq!(cfg.max_delay_ms, 30000);

        // Keep-alive tuning.
        assert_eq!(cfg.ping_interval_secs, 30);
        assert_eq!(cfg.pong_timeout_secs, 60);

        // Behaviour flags.
        assert!(cfg.auto_reconnect);
        assert!(cfg.auto_resubscribe);

        // Channel sizing.
        assert_eq!(cfg.event_channel_capacity, 1000);
        assert_eq!(cfg.command_channel_capacity, 100);
    }

    /// Checks the un-jittered exponential backoff arithmetic against the default config.
    #[test]
    fn test_backoff_calculation() {
        let cfg = WebSocketConfig::default();
        let backoff = |exponent: u32| cfg.base_delay_ms * 2u64.pow(exponent);

        // The delay starts at the base delay and doubles with each attempt.
        for (exponent, expected) in [(0u32, 1000u64), (1, 2000), (2, 4000)] {
            assert_eq!(backoff(exponent), expected, "exponent {}", exponent);
        }

        // A large exponent must be clamped to the configured maximum.
        assert_eq!(backoff(10).min(cfg.max_delay_ms), 30000);
    }
}