stream_tungstenite/
client.rs

1//! WebSocket client with automatic reconnection.
2//!
3//! This module provides the main WebSocket client API that integrates:
4//! - Connection management with automatic reconnection
5//! - Customizable retry strategies
6//! - Application-level handshakes
7//! - Extension system for lifecycle and message handling
8//!
9//! # Example
10//!
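//! ```rust,ignore
//! use std::sync::Arc;
//! use stream_tungstenite::WebSocketClient;
//! use tungstenite::Message;
//!
//! // Create a client with default configuration
//! let client = Arc::new(WebSocketClient::new("wss://example.com/ws"));
//!
//! // Subscribe before starting so no early messages are missed
//! let mut messages = client.subscribe();
//!
//! // Drive the connection (and reconnections) in a background task
//! let runner = Arc::clone(&client);
//! tokio::spawn(async move { runner.run().await });
//!
//! // Send messages; this fails with `SendError::NotConnected` until a session is up
//! client.send(Message::Text("hello".into())).await?;
//!
//! // Receive messages (`recv` returns `Err` on lag or when the channel closes)
//! while let Ok(msg) = messages.recv().await {
//!     println!("Received: {:?}", msg);
//! }
//! ```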
28
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::sync::Arc;
31use std::time::Duration;
32
33use futures_util::StreamExt;
34use tokio::sync::{broadcast, mpsc, watch, RwLock};
35use tokio_util::sync::CancellationToken;
36use tungstenite::protocol::WebSocketConfig;
37use tungstenite::Message;
38
39use crate::connection::ConnectionEvent;
40use crate::connection::{
41    ConnectionSnapshot, ConnectionSupervisor, DefaultConnector, ExponentialBackoff, NoRetry,
42    RetryStrategy,
43};
44use crate::error::{ClientError, DisconnectReason, HandshakeError, SendError, SupervisorError};
45use crate::extension::{Extension, ExtensionHost};
46use crate::handshake::{BoxHandshaker, Handshaker, NoOpHandshaker};
47use crate::message::{DispatcherConfig, MessageDispatcher, ProcessorErrorPolicy, SharedMessage};
48
49// (Reserved) WebSocket stream type alias (default connector's stream)
50// type WsStream = DefaultWsStream;
51
52/// Configuration for the WebSocket client
53#[derive(Clone)]
54pub struct ClientConfig {
55    /// Receive timeout - disconnect if no message received within this duration
56    pub receive_timeout: Duration,
57    /// Whether to exit immediately if the first connection fails
58    pub exit_on_first_failure: bool,
59    /// Connection timeout for establishing TCP/WebSocket
60    pub connect_timeout: Duration,
61    /// Delay before retrying after handshake failures
62    pub handshake_retry_delay: Duration,
63    /// WebSocket protocol configuration
64    pub ws_config: Option<WebSocketConfig>,
65    /// Disable Nagle's algorithm for lower latency
66    pub disable_nagle: bool,
67    /// Channel buffer size for message broadcasting
68    pub channel_buffer_size: usize,
69    /// Outgoing send queue capacity (bounded channel for backpressure)
70    pub send_queue_capacity: usize,
71    /// Policy for handling extension processor errors
72    pub processor_error_policy: ProcessorErrorPolicy,
73}
74
75impl Default for ClientConfig {
76    fn default() -> Self {
77        Self {
78            receive_timeout: Duration::from_secs(20),
79            exit_on_first_failure: false,
80            connect_timeout: Duration::from_secs(30),
81            handshake_retry_delay: Duration::from_secs(5),
82            ws_config: None,
83            disable_nagle: false,
84            channel_buffer_size: 256,
85            send_queue_capacity: 256,
86            processor_error_policy: ProcessorErrorPolicy::Ignore,
87        }
88    }
89}
90
91impl ClientConfig {
92    /// Create a new configuration with default values
93    #[must_use]
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    /// Set the receive timeout
99    #[must_use]
100    pub const fn with_receive_timeout(mut self, timeout: Duration) -> Self {
101        self.receive_timeout = timeout;
102        self
103    }
104
105    /// Set whether to exit on first connection failure
106    #[must_use]
107    pub const fn with_exit_on_first_failure(mut self, exit: bool) -> Self {
108        self.exit_on_first_failure = exit;
109        self
110    }
111
112    /// Set connect timeout
113    #[must_use]
114    pub const fn with_connect_timeout(mut self, timeout: Duration) -> Self {
115        self.connect_timeout = timeout;
116        self
117    }
118
119    /// Set handshake retry delay
120    #[must_use]
121    pub const fn with_handshake_retry_delay(mut self, delay: Duration) -> Self {
122        self.handshake_retry_delay = delay;
123        self
124    }
125
126    /// Set WebSocket protocol configuration
127    #[must_use]
128    #[allow(clippy::missing_const_for_fn)] // WebSocketConfig is not const-compatible
129    pub fn with_ws_config(mut self, config: WebSocketConfig) -> Self {
130        self.ws_config = Some(config);
131        self
132    }
133
134    /// Disable Nagle's algorithm
135    #[must_use]
136    pub const fn with_nodelay(mut self, nodelay: bool) -> Self {
137        self.disable_nagle = nodelay;
138        self
139    }
140
141    /// Set channel buffer size
142    #[must_use]
143    pub const fn with_channel_buffer(mut self, size: usize) -> Self {
144        self.channel_buffer_size = size;
145        self
146    }
147
148    /// Set bounded send queue capacity (for backpressure)
149    #[must_use]
150    pub const fn with_send_queue_capacity(mut self, cap: usize) -> Self {
151        self.send_queue_capacity = cap;
152        self
153    }
154
155    /// Set processor error handling policy for message extensions
156    #[must_use]
157    pub const fn with_processor_error_policy(mut self, policy: ProcessorErrorPolicy) -> Self {
158        self.processor_error_policy = policy;
159        self
160    }
161
162    /// Preset: Fast reconnection for low-latency scenarios
163    #[must_use]
164    pub const fn fast_reconnect() -> Self {
165        Self {
166            receive_timeout: Duration::from_secs(10),
167            exit_on_first_failure: false,
168            connect_timeout: Duration::from_secs(10),
169            handshake_retry_delay: Duration::from_millis(500),
170            ws_config: None,
171            disable_nagle: true,
172            channel_buffer_size: 512,
173            send_queue_capacity: 512,
174            processor_error_policy: ProcessorErrorPolicy::Ignore,
175        }
176    }
177
178    /// Preset: Stable connection for long-running scenarios
179    #[must_use]
180    pub const fn stable_connection() -> Self {
181        Self {
182            receive_timeout: Duration::from_secs(60),
183            exit_on_first_failure: false,
184            connect_timeout: Duration::from_secs(60),
185            handshake_retry_delay: Duration::from_secs(2),
186            ws_config: None,
187            disable_nagle: false,
188            channel_buffer_size: 128,
189            send_queue_capacity: 128,
190            processor_error_policy: ProcessorErrorPolicy::Ignore,
191        }
192    }
193}
194
195impl From<&ClientConfig> for DispatcherConfig {
196    fn from(config: &ClientConfig) -> Self {
197        Self::new()
198            .with_receive_timeout(config.receive_timeout)
199            .with_broadcast_capacity(config.channel_buffer_size)
200            .with_send_buffer_capacity(config.send_queue_capacity)
201            .with_processor_error_policy(config.processor_error_policy)
202    }
203}
204
205/// Sender handle for sending messages
206#[derive(Clone)]
207pub struct Sender {
208    tx: mpsc::Sender<Message>,
209}
210
211impl Sender {
212    /// Send a message
213    pub fn send(&self, message: Message) -> Result<(), SendError> {
214        match self.tx.try_send(message) {
215            Ok(()) => Ok(()),
216            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Err(SendError::ChannelFull),
217            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => Err(SendError::ChannelClosed),
218        }
219    }
220
221    /// Send a text message
222    pub fn send_text(&self, text: impl Into<String>) -> Result<(), SendError> {
223        self.send(Message::Text(text.into().into()))
224    }
225
226    /// Send a binary message
227    pub fn send_binary(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
228        self.send(Message::Binary(data.into().into()))
229    }
230
231    /// Send a ping
232    pub fn ping(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
233        self.send(Message::Ping(data.into().into()))
234    }
235
236    /// Send a message (async, blocking on capacity)
237    pub async fn send_async(&self, message: Message) -> Result<(), SendError> {
238        self.tx
239            .send(message)
240            .await
241            .map_err(|_| SendError::ChannelClosed)
242    }
243
244    /// Send a text message (async)
245    pub async fn send_text_async(&self, text: impl Into<String>) -> Result<(), SendError> {
246        self.send_async(Message::Text(text.into().into())).await
247    }
248
249    /// Send a binary message (async)
250    pub async fn send_binary_async(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
251        self.send_async(Message::Binary(data.into().into())).await
252    }
253
254    /// Send a ping (async)
255    pub async fn ping_async(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
256        self.send_async(Message::Ping(data.into().into())).await
257    }
258
259    /// Send a message with timeout (async)
260    pub async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
261        match tokio::time::timeout(timeout, self.tx.send(message)).await {
262            Ok(Ok(())) => Ok(()),
263            Ok(Err(_)) => Err(SendError::ChannelClosed),
264            Err(_) => Err(SendError::Timeout(timeout)),
265        }
266    }
267}
268
269/// Internal runtime that coordinates messaging state and lifecycle control.
270struct ClientRuntime {
271    is_running: AtomicBool,
272    cancel: CancellationToken,
273    message_tx: broadcast::Sender<SharedMessage>,
274    send_tx: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
275    dispatcher: Arc<MessageDispatcher<crate::connection::DefaultWsStream>>,
276    run_state: watch::Sender<bool>,
277}
278
279impl ClientRuntime {
280    fn new(config: &ClientConfig) -> Self {
281        let (message_tx, _) = broadcast::channel(config.channel_buffer_size);
282        let dispatcher_config = DispatcherConfig::from(config);
283        let (run_state, _rx) = watch::channel(false);
284
285        Self {
286            is_running: AtomicBool::new(false),
287            cancel: CancellationToken::new(),
288            message_tx,
289            send_tx: Arc::new(RwLock::new(None)),
290            dispatcher: Arc::new(MessageDispatcher::new(dispatcher_config)),
291            run_state,
292        }
293    }
294
295    fn begin_run(&self) -> Result<(), ClientError> {
296        if self.is_running.swap(true, Ordering::SeqCst) {
297            Err(ClientError::AlreadyRunning)
298        } else {
299            let _ = self.run_state.send(true);
300            Ok(())
301        }
302    }
303
304    fn finish_run(&self) {
305        self.is_running.store(false, Ordering::SeqCst);
306        let _ = self.run_state.send(false);
307    }
308
309    fn cancel(&self) {
310        self.cancel.cancel();
311    }
312
313    fn cancel_token(&self) -> CancellationToken {
314        self.cancel.clone()
315    }
316
317    fn is_cancelled(&self) -> bool {
318        self.cancel.is_cancelled()
319    }
320
321    fn is_running(&self) -> bool {
322        self.is_running.load(Ordering::SeqCst)
323    }
324
325    fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
326        self.message_tx.subscribe()
327    }
328
329    fn message_channel(&self) -> broadcast::Sender<SharedMessage> {
330        self.message_tx.clone()
331    }
332
333    fn dispatcher(&self) -> Arc<MessageDispatcher<crate::connection::DefaultWsStream>> {
334        self.dispatcher.clone()
335    }
336
337    async fn sender(&self) -> Option<Sender> {
338        let guard = self.send_tx.read().await;
339        guard.as_ref().map(|tx| Sender { tx: tx.clone() })
340    }
341
342    async fn send(&self, message: Message) -> Result<(), SendError> {
343        let guard = self.send_tx.read().await;
344        guard.as_ref().map_or(Err(SendError::NotConnected), |tx| {
345            match tx.try_send(message) {
346                Ok(()) => Ok(()),
347                Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Err(SendError::ChannelFull),
348                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
349                    Err(SendError::ChannelClosed)
350                }
351            }
352        })
353    }
354
355    async fn send_async(&self, message: Message) -> Result<(), SendError> {
356        let tx = self
357            .send_tx
358            .read()
359            .await
360            .as_ref()
361            .ok_or(SendError::NotConnected)?
362            .clone();
363        tx.send(message).await.map_err(|_| SendError::ChannelClosed)
364    }
365
366    async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
367        let tx = self
368            .send_tx
369            .read()
370            .await
371            .as_ref()
372            .ok_or(SendError::NotConnected)?
373            .clone();
374        match tokio::time::timeout(timeout, tx.send(message)).await {
375            Ok(Ok(())) => Ok(()),
376            Ok(Err(_)) => Err(SendError::ChannelClosed),
377            Err(_) => Err(SendError::Timeout(timeout)),
378        }
379    }
380
381    async fn set_send_channel(&self, tx: mpsc::Sender<Message>) {
382        let mut guard = self.send_tx.write().await;
383        *guard = Some(tx);
384    }
385
386    async fn clear_send_channel(&self) {
387        let mut guard = self.send_tx.write().await;
388        *guard = None;
389    }
390
391    fn run_state_receiver(&self) -> watch::Receiver<bool> {
392        self.run_state.subscribe()
393    }
394}
395
396/// WebSocket client with automatic reconnection (thin wrapper over `ConnectionSupervisor`)
397pub struct WebSocketClient {
398    /// Target URI (e.g. `<wss://example.com/ws>`)
399    uri: String,
400    config: ClientConfig,
401    handshaker: BoxHandshaker,
402    extension_host: Arc<ExtensionHost>,
403    supervisor: ConnectionSupervisor<DefaultConnector>,
404    runtime: Arc<ClientRuntime>,
405}
406
407impl WebSocketClient {
408    /// Create a new WebSocket client builder
409    pub fn builder(uri: impl Into<String>) -> WebSocketClientBuilder {
410        WebSocketClientBuilder::new(uri)
411    }
412
413    /// Create a client with default configuration
414    pub fn new(uri: impl Into<String>) -> Self {
415        Self::builder(uri).build()
416    }
417
418    /// Subscribe to incoming messages
419    ///
420    /// Returns a receiver for shared messages wrapped in `Arc<Message>` for zero-copy broadcasting.
421    /// To work with the message:
422    /// - **Read-only access**: `msg.as_ref()` or dereference `&*msg`
423    /// - **Need owned copy**: `Arc::try_unwrap(msg).unwrap_or_else(|arc| (*arc).clone())`
424    /// - **Clone specific data**: `msg.clone()` clones the Arc (cheap), `(*msg).clone()` clones the Message
425    #[must_use]
426    pub fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
427        self.runtime.subscribe()
428    }
429
430    /// Get target URI
431    #[must_use]
432    pub fn uri(&self) -> &str {
433        &self.uri
434    }
435
436    /// Subscribe to connection events
437    #[must_use]
438    pub fn subscribe_events(&self) -> broadcast::Receiver<ConnectionEvent> {
439        self.supervisor.subscribe()
440    }
441
442    /// Get a sender handle for sending messages
443    pub async fn sender(&self) -> Option<Sender> {
444        self.runtime.sender().await
445    }
446
447    /// Send a message (convenience method)
448    pub async fn send(&self, message: Message) -> Result<(), SendError> {
449        self.runtime.send(message).await
450    }
451    /// Send a message (async, blocking on capacity)
452    pub async fn send_async(&self, message: Message) -> Result<(), SendError> {
453        self.runtime.send_async(message).await
454    }
455
456    /// Send a message with timeout (async)
457    pub async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
458        self.runtime.send_timeout(message, timeout).await
459    }
460
461    /// Get current connection state snapshot
462    pub async fn state(&self) -> ConnectionSnapshot {
463        self.supervisor.snapshot().await
464    }
465
466    /// Check if currently connected
467    #[must_use]
468    pub fn is_connected(&self) -> bool {
469        self.supervisor.is_connected()
470    }
471
472    /// Register an extension
473    pub async fn register_extension<E: Extension + 'static>(
474        &self,
475        extension: E,
476    ) -> Result<(), ClientError> {
477        self.extension_host
478            .register(extension)
479            .await
480            .map_err(ClientError::Extension)
481    }
482
483    /// Run the client (blocking)
484    pub async fn run(&self) -> Result<(), ClientError> {
485        self.runtime.begin_run()?;
486        let result = self.run_loop().await;
487        self.runtime.finish_run();
488        result
489    }
490
491    /// Shutdown the client
492    pub fn shutdown(&self) {
493        // Fast shutdown: immediately signal cancellation and stop supervisor.
494        // Pending messages may be dropped; use `shutdown_graceful` to wait for
495        // the run loop to finish instead.
496        self.runtime.cancel();
497        // Also request supervisor to stop any in-flight connect attempts
498        self.supervisor.shutdown();
499    }
500
501    /// Shutdown the client gracefully, waiting for the run loop to exit or timing out.
502    ///
503    /// This method triggers [`shutdown()`](Self::shutdown) and then waits for the run loop
504    /// to finish. If the client is not running, it returns immediately.
505    ///
506    /// Note: This only waits for the run loop to exit (extensions receive disconnect/shutdown
507    /// hooks). It does **not** guarantee that pending user messages make it to the peer;
508    /// callers should coordinate with their own protocols if that guarantee is required.
509    ///
510    /// # Errors
511    ///
512    /// Returns [`ClientError::ShutdownTimeout`] if the run loop does not exit within
513    /// the specified timeout duration.
514    pub async fn shutdown_graceful(&self, timeout: Duration) -> Result<(), ClientError> {
515        let mut run_state = self.runtime.run_state_receiver();
516        self.shutdown();
517        if !self.runtime.is_running() || !*run_state.borrow() {
518            return Ok(());
519        }
520
521        let wait_for_shutdown = async {
522            while run_state.changed().await.is_ok() {
523                if !*run_state.borrow() {
524                    break;
525                }
526            }
527        };
528
529        match tokio::time::timeout(timeout, wait_for_shutdown).await {
530            Ok(()) => Ok(()),
531            Err(_) => Err(ClientError::ShutdownTimeout(timeout)),
532        }
533    }
534
535    async fn run_loop(&self) -> Result<(), ClientError> {
536        loop {
537            if self.runtime.is_cancelled() {
538                tracing::info!("Shutdown requested");
539                self.extension_host.shutdown().await?;
540                return Ok(());
541            }
542            // Connect and establish a session (attach dispatcher, update extensions)
543            let (stream, mut send_rx, connection_id) = match self.establish_session().await {
544                Ok(t) => t,
545                Err(ClientError::Supervisor(SupervisorError::Shutdown)) => {
546                    tracing::info!("Supervisor shutdown requested");
547                    self.extension_host.shutdown().await?;
548                    return Ok(());
549                }
550                Err(ClientError::Handshake(_)) => {
551                    // Handshake failure already handled with delay; retry
552                    continue;
553                }
554                Err(e) => {
555                    self.extension_host.shutdown().await?;
556                    return Err(e);
557                }
558            };
559
560            // Spawn receiver and forwarder tasks
561            let (mut recv_task, forward_task) = self.spawn_receiver_and_bridge(stream);
562
563            // Drive outgoing sends and watch receiver/cancel
564            let disconnect_reason = self.drive_session(&mut send_rx, &mut recv_task).await;
565
566            // Cleanup and notify
567            self.cleanup_session(forward_task, disconnect_reason, connection_id)
568                .await?;
569        }
570    }
571
572    async fn connect_via_supervisor(
573        &self,
574    ) -> Result<crate::connection::DefaultWsStream, ClientError> {
575        match self.supervisor.connect().await {
576            Ok(stream) => Ok(stream),
577            Err(e) => Err(ClientError::Supervisor(e)),
578        }
579    }
580
581    async fn establish_session(
582        &self,
583    ) -> Result<
584        (
585            futures_util::stream::SplitStream<crate::connection::DefaultWsStream>,
586            mpsc::Receiver<Message>,
587            u64,
588        ),
589        ClientError,
590    > {
591        let ws_stream = self.connect_via_supervisor().await?;
592
593        // Split stream
594        let (mut sink, mut stream) = ws_stream.split();
595
596        // Create send channel
597        let (send_tx, send_rx) = mpsc::channel::<Message>(self.config.send_queue_capacity);
598        self.runtime.set_send_channel(send_tx).await;
599
600        // Perform handshake
601        if let Err(e) = self.perform_handshake(&mut sink, &mut stream).await {
602            tracing::error!(error = ?e, "Handshake failed");
603            self.supervisor
604                .mark_disconnected(DisconnectReason::Error(e.to_string()))
605                .await;
606            // Respect handshaker retryability semantics
607            if self.handshaker.is_retryable(&e) {
608                tokio::time::sleep(self.config.handshake_retry_delay).await;
609                return Err(ClientError::Handshake(e));
610            }
611            // Graceful stop: emit Fatal event, request supervisor shutdown, and surface Shutdown to run_loop
612            self.supervisor
613                .fatal(crate::error::ConnectError::HandshakeFailed(e.to_string()));
614            self.supervisor.shutdown();
615            return Err(ClientError::Supervisor(SupervisorError::Shutdown));
616        }
617
618        // Attach dispatcher for sending path
619        self.runtime.dispatcher().attach(sink).await;
620
621        // Update extension context
622        let connection_id = self.supervisor.connection_id();
623        let snapshot = self.supervisor.snapshot().await;
624        self.extension_host
625            .update_context(connection_id, snapshot.reconnect_count)
626            .await;
627        let _ = self.extension_host.notify_connect().await;
628        tracing::info!(connection_id = connection_id, "Connected");
629
630        Ok((stream, send_rx, connection_id))
631    }
632
633    fn spawn_receiver_and_bridge(
634        &self,
635        stream: futures_util::stream::SplitStream<crate::connection::DefaultWsStream>,
636    ) -> (
637        tokio::task::JoinHandle<Result<(), crate::error::ReceiveError>>,
638        tokio::task::JoinHandle<()>,
639    ) {
640        // Bridge dispatcher broadcasts to client-level message_tx
641        let dispatcher = self.runtime.dispatcher();
642        let ext_host = self.extension_host.clone();
643        let mut disp_rx = dispatcher.subscribe();
644        let client_broadcast = self.runtime.message_channel();
645        let cancel_token = self.runtime.cancel_token();
646        let forward_task = tokio::spawn(async move {
647            loop {
648                tokio::select! {
649                    () = cancel_token.cancelled() => break,
650                    msg = disp_rx.recv() => {
651                        match msg {
652                            Ok(m) => { let _ = client_broadcast.send(m); }
653                            Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { /* drop lagged */ }
654                            Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
655                        }
656                    }
657                }
658            }
659        });
660        // Activity handle for last-activity updates
661        let activity = self.supervisor.activity_handle();
662        let recv_task = tokio::spawn(async move {
663            dispatcher
664                .receive_loop_with_processor(
665                    stream,
666                    move || {
667                        let activity = activity.clone();
668                        async move { activity.update().await }
669                    },
670                    move |msg| {
671                        let ext_host = ext_host.clone();
672                        async move { ext_host.process_message(&msg).await }
673                    },
674                )
675                .await
676        });
677
678        (recv_task, forward_task)
679    }
680
681    async fn drive_session(
682        &self,
683        send_rx: &mut mpsc::Receiver<Message>,
684        recv_task: &mut tokio::task::JoinHandle<Result<(), crate::error::ReceiveError>>,
685    ) -> Option<DisconnectReason> {
686        let cancel_token = self.runtime.cancel_token();
687        let dispatcher = self.runtime.dispatcher();
688        loop {
689            tokio::select! {
690                () = cancel_token.cancelled() => {
691                    recv_task.abort();
692                    return Some(DisconnectReason::Shutdown);
693                }
694                res = &mut *recv_task => {
695                    return match res {
696                        Ok(Ok(())) => Some(DisconnectReason::Normal),
697                        Ok(Err(e)) => Some(match e {
698                            crate::error::ReceiveError::Timeout(_) => DisconnectReason::Timeout,
699                            crate::error::ReceiveError::StreamClosed => DisconnectReason::Normal,
700                            crate::error::ReceiveError::WebSocket(err) => DisconnectReason::Error(err),
701                        }),
702                        Err(_) => Some(DisconnectReason::Error("receiver task aborted".to_string())),
703                    }
704                }
705                msg = send_rx.recv() => {
706                    if let Some(message) = msg {
707                        if let Err(e) = dispatcher.send(message).await {
708                            return Some(DisconnectReason::Error(format!("send error: {e:?}")));
709                        }
710                    } else {
711                        return Some(DisconnectReason::Error("send channel closed".to_string()));
712                    }
713                }
714            }
715        }
716    }
717
718    async fn cleanup_session(
719        &self,
720        forward_task: tokio::task::JoinHandle<()>,
721        disconnect_reason: Option<DisconnectReason>,
722        connection_id: u64,
723    ) -> Result<(), ClientError> {
724        self.runtime.clear_send_channel().await;
725        forward_task.abort();
726        self.runtime.dispatcher().detach().await;
727
728        let reason = disconnect_reason.unwrap_or(DisconnectReason::Normal);
729        self.supervisor.mark_disconnected(reason.clone()).await;
730        let _ = self.extension_host.notify_disconnect().await;
731
732        tracing::info!(
733            connection_id = connection_id,
734            reason = ?Some(reason.clone()),
735            "Disconnected"
736        );
737
738        if self.runtime.is_cancelled() {
739            tracing::info!("Shutdown requested after disconnect");
740            self.extension_host.shutdown().await?;
741            return Ok(());
742        }
743
744        Ok(())
745    }
746
747    async fn perform_handshake<S, St>(
748        &self,
749        sink: &mut S,
750        stream: &mut St,
751    ) -> Result<(), HandshakeError>
752    where
753        S: futures_util::Sink<Message, Error = tungstenite::Error> + Unpin + Send,
754        St: futures_util::Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send,
755    {
756        use crate::context::ConnectionContext;
757
758        let snapshot = self.supervisor.snapshot().await;
759        let context =
760            ConnectionContext::new(snapshot.id).with_reconnect_count(snapshot.reconnect_count);
761
762        self.handshaker
763            .handshake_with_timeout(sink, stream, &context)
764            .await
765    }
766
767    // old message_loop removed; receive path is now fully handled by MessageDispatcher
768}
769
770/// Builder for `WebSocketClient`
771pub struct WebSocketClientBuilder {
772    uri: String,
773    config: ClientConfig,
774    retry_strategy: Box<dyn RetryStrategy>,
775    handshaker: BoxHandshaker,
776}
777
778impl WebSocketClientBuilder {
779    /// Create a new builder
780    pub fn new(uri: impl Into<String>) -> Self {
781        Self {
782            uri: uri.into(),
783            config: ClientConfig::default(),
784            retry_strategy: Box::new(ExponentialBackoff::default()),
785            handshaker: Box::new(NoOpHandshaker),
786        }
787    }
788
789    /// Set client configuration
790    #[must_use]
791    #[allow(clippy::missing_const_for_fn)] // ClientConfig is not const-compatible
792    pub fn config(mut self, config: ClientConfig) -> Self {
793        self.config = config;
794        self
795    }
796
797    /// Set receive timeout
798    #[must_use]
799    pub const fn receive_timeout(mut self, timeout: Duration) -> Self {
800        self.config.receive_timeout = timeout;
801        self
802    }
803
804    /// Set retry strategy
805    #[must_use]
806    pub fn retry_strategy<S: RetryStrategy + 'static>(mut self, strategy: S) -> Self {
807        self.retry_strategy = Box::new(strategy);
808        self
809    }
810
811    /// Set handshaker
812    #[must_use]
813    pub fn handshaker<H: Handshaker + 'static>(mut self, handshaker: H) -> Self {
814        self.handshaker = Box::new(handshaker);
815        self
816    }
817
818    /// Disable retry (single connection attempt)
819    #[must_use]
820    pub fn no_retry(mut self) -> Self {
821        self.retry_strategy = Box::new(NoRetry);
822        self
823    }
824
825    /// Use exponential backoff with custom parameters
826    #[must_use]
827    pub fn exponential_backoff(
828        mut self,
829        initial: Duration,
830        max: Duration,
831        multiplier: f64,
832    ) -> Self {
833        self.retry_strategy = Box::new(
834            ExponentialBackoff::new(initial, max)
835                .with_factor(multiplier)
836                .with_jitter(0.1),
837        );
838        self
839    }
840
841    /// Build the client
842    #[must_use]
843    pub fn build(self) -> WebSocketClient {
844        let runtime = Arc::new(ClientRuntime::new(&self.config));
845
846        // Build connector from config
847        let connector = DefaultConnector::new()
848            .with_nodelay(self.config.disable_nagle)
849            // ws_config is Option<WebSocketConfig>; DefaultConnector::with_ws_config takes a concrete value
850            // Apply only if provided
851        ;
852        let connector = if let Some(ws_cfg) = self.config.ws_config {
853            connector.with_ws_config(ws_cfg)
854        } else {
855            connector
856        };
857
858        // Build supervisor config from client config + retry strategy
859        let mut sup_cfg = crate::connection::SupervisorConfig::new();
860        sup_cfg.retry_strategy = self.retry_strategy;
861        sup_cfg.exit_on_first_failure = self.config.exit_on_first_failure;
862        sup_cfg.connect_timeout = self.config.connect_timeout;
863
864        let supervisor =
865            ConnectionSupervisor::with_connector(self.uri.clone(), connector).with_config(sup_cfg);
866
867        WebSocketClient {
868            uri: self.uri,
869            config: self.config,
870            handshaker: self.handshaker,
871            extension_host: Arc::new(ExtensionHost::new()),
872            supervisor,
873            runtime,
874        }
875    }
876}
877
878/// Extension trait for running client in background
879pub trait WebSocketClientExt {
880    /// Spawn the client as a background task
881    fn spawn(self: Arc<Self>) -> tokio::task::JoinHandle<Result<(), ClientError>>;
882}
883
884impl WebSocketClientExt for WebSocketClient {
885    fn spawn(self: Arc<Self>) -> tokio::task::JoinHandle<Result<(), ClientError>> {
886        tokio::spawn(async move { self.run().await })
887    }
888}
889
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_config_defaults() {
        let config = ClientConfig::default();
        assert_eq!(config.receive_timeout, Duration::from_secs(20));
        assert!(!config.exit_on_first_failure);
        assert!(!config.disable_nagle);
    }

    #[test]
    fn test_client_config_presets() {
        let fast = ClientConfig::fast_reconnect();
        assert_eq!(fast.receive_timeout, Duration::from_secs(10));
        assert!(fast.disable_nagle);

        let stable = ClientConfig::stable_connection();
        assert_eq!(stable.receive_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder() {
        let client = WebSocketClient::builder("ws://localhost:8080")
            .receive_timeout(Duration::from_secs(30))
            .no_retry()
            .build();

        assert_eq!(client.config.receive_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_sender_backpressure_full() {
        use tokio::sync::mpsc;
        // Capacity-1 channel; the named `_rx` keeps the receiver (and channel) alive.
        let (tx, _rx) = mpsc::channel::<Message>(1);
        let sender = Sender { tx };

        // Fill the channel
        assert!(sender.send(Message::Text("a".into())).is_ok());
        // Next try should be full
        let res = sender.send(Message::Text("b".into()));
        assert!(matches!(res, Err(crate::error::SendError::ChannelFull)));
    }

    #[tokio::test]
    async fn test_send_before_run_is_not_connected() {
        let client = WebSocketClient::builder("wss://example.test/ws")
            .no_retry()
            .build();

        assert!(client.sender().await.is_none());
        let res = client.send(Message::Text("x".into())).await;
        assert!(matches!(res, Err(crate::error::SendError::NotConnected)));
    }

    #[tokio::test]
    async fn test_shutdown_graceful_when_idle_returns_ok() {
        let client = WebSocketClient::builder("wss://example.test/ws")
            .no_retry()
            .build();

        let res = client
            .shutdown_graceful(Duration::from_millis(50))
            .await;
        assert!(res.is_ok(), "shutdown_graceful on idle client failed: {res:?}");
    }

    #[tokio::test]
    async fn test_client_shutdown_exits_quickly() {
        // Build a normal client but cancel before running to avoid any network attempt
        let client = WebSocketClient::builder("wss://example.test/ws")
            .receive_timeout(std::time::Duration::from_millis(100))
            .no_retry()
            .build();
        let client = std::sync::Arc::new(client);

        // Cancel before run starts
        client.shutdown();

        let h = {
            let c = client.clone();
            tokio::spawn(async move { c.run().await })
        };

        let res = tokio::time::timeout(std::time::Duration::from_secs(1), h).await;
        assert!(res.is_ok(), "run() did not exit in time");
        let run_res = res.unwrap().unwrap();
        assert!(run_res.is_ok(), "run() returned error: {run_res:?}");
    }
}