stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
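//! WebSocket client with automatic reconnection.
//!
//! This module provides the main WebSocket client API, which integrates:
//! - Connection management with automatic reconnection
//! - Customizable retry strategies
//! - Application-level handshakes
//! - An extension system for lifecycle and message handling
//!
//! # Example
//!
//! ```rust,ignore
//! use std::sync::Arc;
//!
//! use stream_tungstenite::WebSocketClient;
//! use tungstenite::Message;
//!
//! // Create a client with default configuration
//! let client = Arc::new(WebSocketClient::new("wss://example.com/ws"));
//!
//! // Subscribe to messages before starting, so no early message is missed
//! let mut messages = client.subscribe();
//!
//! // `run` drives connecting, handshaking and reconnecting until shutdown
//! let runner = Arc::clone(&client);
//! tokio::spawn(async move { runner.run().await });
//!
//! // Send messages (fails with `SendError::NotConnected` until a session is established)
//! if let Err(e) = client.send(Message::Text("hello".into())).await {
//!     eprintln!("send failed: {e:?}");
//! }
//!
//! // Receive messages; `recv` returns a `Result`, so this loop stops on the first
//! // error (channel closed or receiver lagged)
//! while let Ok(msg) = messages.recv().await {
//!     println!("Received: {:?}", msg);
//! }
//! ```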

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

use futures_util::StreamExt;
use tokio::sync::{broadcast, mpsc, watch, RwLock};
use tokio_util::sync::CancellationToken;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Message;

use crate::connection::ConnectionEvent;
use crate::connection::{
    ConnectionSnapshot, ConnectionSupervisor, DefaultConnector, ExponentialBackoff, NoRetry,
    RetryStrategy,
};
use crate::error::{ClientError, DisconnectReason, HandshakeError, SendError, SupervisorError};
use crate::extension::{Extension, ExtensionHost};
use crate::handshake::{BoxHandshaker, Handshaker, NoOpHandshaker};
use crate::message::{DispatcherConfig, MessageDispatcher, ProcessorErrorPolicy, SharedMessage};

// (Reserved) WebSocket stream type alias (default connector's stream)
// type WsStream = DefaultWsStream;

/// Configuration for the WebSocket client
///
/// Built with [`ClientConfig::new`] / [`Default`] plus the `with_*` setters, or taken from
/// one of the presets ([`ClientConfig::fast_reconnect`], [`ClientConfig::stable_connection`]).
/// Several fields are forwarded into the message dispatcher's configuration
/// (`receive_timeout`, `channel_buffer_size`, `send_queue_capacity`, `processor_error_policy`).
#[derive(Clone)]
pub struct ClientConfig {
    /// Receive timeout - disconnect if no message received within this duration
    pub receive_timeout: Duration,
    /// Whether to exit immediately if the first connection fails
    pub exit_on_first_failure: bool,
    /// Connection timeout for establishing TCP/WebSocket
    pub connect_timeout: Duration,
    /// Delay before retrying after handshake failures the handshaker reports as retryable
    pub handshake_retry_delay: Duration,
    /// WebSocket protocol configuration (`None` keeps tungstenite's defaults)
    pub ws_config: Option<WebSocketConfig>,
    /// Disable Nagle's algorithm for lower latency (forwarded to the connector's `with_nodelay`)
    pub disable_nagle: bool,
    /// Channel buffer size for message broadcasting; also used as the dispatcher's
    /// broadcast capacity. Must be non-zero: tokio's broadcast channel rejects a zero capacity.
    pub channel_buffer_size: usize,
    /// Outgoing send queue capacity (bounded channel for backpressure); also used as the
    /// dispatcher's send buffer capacity. Must be non-zero for tokio's bounded mpsc channel.
    pub send_queue_capacity: usize,
    /// Policy for handling extension processor errors
    pub processor_error_policy: ProcessorErrorPolicy,
}

impl Default for ClientConfig {
    /// Balanced defaults: a 20 s receive timeout, a 30 s connect timeout, a 5 s handshake
    /// retry delay, Nagle left enabled and 256-slot broadcast / send queues.
    fn default() -> Self {
        const QUEUE_SLOTS: usize = 256;

        let receive_timeout = Duration::from_secs(20);
        let connect_timeout = Duration::from_secs(30);
        let handshake_retry_delay = Duration::from_secs(5);

        Self {
            receive_timeout,
            exit_on_first_failure: false,
            connect_timeout,
            handshake_retry_delay,
            ws_config: None,
            disable_nagle: false,
            channel_buffer_size: QUEUE_SLOTS,
            send_queue_capacity: QUEUE_SLOTS,
            processor_error_policy: ProcessorErrorPolicy::Ignore,
        }
    }
}

impl ClientConfig {
    /// Create a new configuration with default values (same as [`ClientConfig::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the receive timeout (disconnect if no message is received within `timeout`).
    #[must_use]
    pub const fn with_receive_timeout(mut self, timeout: Duration) -> Self {
        self.receive_timeout = timeout;
        self
    }

    /// Set whether to exit on first connection failure
    #[must_use]
    pub const fn with_exit_on_first_failure(mut self, exit: bool) -> Self {
        self.exit_on_first_failure = exit;
        self
    }

    /// Set the timeout for establishing the TCP/WebSocket connection
    #[must_use]
    pub const fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Set the delay applied before retrying after a retryable handshake failure
    #[must_use]
    pub const fn with_handshake_retry_delay(mut self, delay: Duration) -> Self {
        self.handshake_retry_delay = delay;
        self
    }

    /// Set WebSocket protocol configuration
    #[must_use]
    #[allow(clippy::missing_const_for_fn)] // WebSocketConfig is not const-compatible
    pub fn with_ws_config(mut self, config: WebSocketConfig) -> Self {
        self.ws_config = Some(config);
        self
    }

    /// Set whether to disable Nagle's algorithm (`TCP_NODELAY`).
    ///
    /// `true` disables Nagle for lower latency; `false` leaves it enabled.
    #[must_use]
    pub const fn with_nodelay(mut self, nodelay: bool) -> Self {
        self.disable_nagle = nodelay;
        self
    }

    /// Set the capacity of the incoming-message broadcast channel.
    ///
    /// A `size` of 0 is raised to 1: tokio's broadcast channel panics on a zero capacity,
    /// which would otherwise make the client panic when it creates its channels.
    #[must_use]
    pub const fn with_channel_buffer(mut self, size: usize) -> Self {
        self.channel_buffer_size = if size == 0 { 1 } else { size };
        self
    }

    /// Set bounded send queue capacity (for backpressure).
    ///
    /// A `cap` of 0 is raised to 1: tokio's bounded mpsc channel panics on a zero capacity,
    /// which would otherwise make the client panic when it opens a session's send queue.
    #[must_use]
    pub const fn with_send_queue_capacity(mut self, cap: usize) -> Self {
        self.send_queue_capacity = if cap == 0 { 1 } else { cap };
        self
    }

    /// Set processor error handling policy for message extensions
    #[must_use]
    pub const fn with_processor_error_policy(mut self, policy: ProcessorErrorPolicy) -> Self {
        self.processor_error_policy = policy;
        self
    }

    /// Preset: Fast reconnection for low-latency scenarios
    ///
    /// Shorter timeouts, a 500 ms handshake retry delay, Nagle disabled and larger queues.
    #[must_use]
    pub const fn fast_reconnect() -> Self {
        Self {
            receive_timeout: Duration::from_secs(10),
            exit_on_first_failure: false,
            connect_timeout: Duration::from_secs(10),
            handshake_retry_delay: Duration::from_millis(500),
            ws_config: None,
            disable_nagle: true,
            channel_buffer_size: 512,
            send_queue_capacity: 512,
            processor_error_policy: ProcessorErrorPolicy::Ignore,
        }
    }

    /// Preset: Stable connection for long-running scenarios
    ///
    /// Longer receive/connect timeouts (60 s), a 2 s handshake retry delay and smaller queues.
    #[must_use]
    pub const fn stable_connection() -> Self {
        Self {
            receive_timeout: Duration::from_secs(60),
            exit_on_first_failure: false,
            connect_timeout: Duration::from_secs(60),
            handshake_retry_delay: Duration::from_secs(2),
            ws_config: None,
            disable_nagle: false,
            channel_buffer_size: 128,
            send_queue_capacity: 128,
            processor_error_policy: ProcessorErrorPolicy::Ignore,
        }
    }
}

impl From<&ClientConfig> for DispatcherConfig {
    fn from(config: &ClientConfig) -> Self {
        Self::new()
            .with_receive_timeout(config.receive_timeout)
            .with_broadcast_capacity(config.channel_buffer_size)
            .with_send_buffer_capacity(config.send_queue_capacity)
            .with_processor_error_policy(config.processor_error_policy)
    }
}

/// Sender handle for sending messages
#[derive(Clone)]
pub struct Sender {
    tx: mpsc::Sender<Message>,
}

impl Sender {
    /// Send a message
    pub fn send(&self, message: Message) -> Result<(), SendError> {
        match self.tx.try_send(message) {
            Ok(()) => Ok(()),
            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Err(SendError::ChannelFull),
            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => Err(SendError::ChannelClosed),
        }
    }

    /// Send a text message
    pub fn send_text(&self, text: impl Into<String>) -> Result<(), SendError> {
        self.send(Message::Text(text.into().into()))
    }

    /// Send a binary message
    pub fn send_binary(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
        self.send(Message::Binary(data.into().into()))
    }

    /// Send a ping
    pub fn ping(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
        self.send(Message::Ping(data.into().into()))
    }

    /// Send a message (async, blocking on capacity)
    pub async fn send_async(&self, message: Message) -> Result<(), SendError> {
        self.tx
            .send(message)
            .await
            .map_err(|_| SendError::ChannelClosed)
    }

    /// Send a text message (async)
    pub async fn send_text_async(&self, text: impl Into<String>) -> Result<(), SendError> {
        self.send_async(Message::Text(text.into().into())).await
    }

    /// Send a binary message (async)
    pub async fn send_binary_async(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
        self.send_async(Message::Binary(data.into().into())).await
    }

    /// Send a ping (async)
    pub async fn ping_async(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
        self.send_async(Message::Ping(data.into().into())).await
    }

    /// Send a message with timeout (async)
    pub async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
        match tokio::time::timeout(timeout, self.tx.send(message)).await {
            Ok(Ok(())) => Ok(()),
            Ok(Err(_)) => Err(SendError::ChannelClosed),
            Err(_) => Err(SendError::Timeout(timeout)),
        }
    }
}

/// Internal runtime that coordinates messaging state and lifecycle control.
///
/// Shared via `Arc` by `WebSocketClient`; holds the state that is reused across
/// connection sessions.
struct ClientRuntime {
    /// Set while `WebSocketClient::run` is active; `begin_run` uses it to reject concurrent runs.
    is_running: AtomicBool,
    /// Cancelled by `shutdown`; observed by the run loop, the session driver and the forwarder task.
    cancel: CancellationToken,
    /// Client-level broadcast of incoming messages, handed out by `subscribe`.
    message_tx: broadcast::Sender<SharedMessage>,
    /// Sender half of the current session's outgoing queue; `None` when no queue is installed.
    send_tx: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
    /// Dispatcher the active connection's sink is attached to; also runs the receive loop.
    dispatcher: Arc<MessageDispatcher<crate::connection::DefaultWsStream>>,
    /// Publishes the running flag so `shutdown_graceful` can wait for the run loop to exit.
    run_state: watch::Sender<bool>,
}

impl ClientRuntime {
    /// Build a runtime whose channel capacities and dispatcher settings come from `config`.
    fn new(config: &ClientConfig) -> Self {
        let (message_tx, _) = broadcast::channel(config.channel_buffer_size);
        let dispatcher_config = DispatcherConfig::from(config);
        let (run_state, _rx) = watch::channel(false);

        Self {
            is_running: AtomicBool::new(false),
            cancel: CancellationToken::new(),
            message_tx,
            send_tx: Arc::new(RwLock::new(None)),
            dispatcher: Arc::new(MessageDispatcher::new(dispatcher_config)),
            run_state,
        }
    }

    /// Mark the runtime as running and publish `true` on the run-state channel.
    ///
    /// # Errors
    ///
    /// [`ClientError::AlreadyRunning`] if a run is already in progress.
    fn begin_run(&self) -> Result<(), ClientError> {
        if self.is_running.swap(true, Ordering::SeqCst) {
            Err(ClientError::AlreadyRunning)
        } else {
            // `watch::Sender::send` fails *without storing the value* when no receiver
            // exists, and none does while running (`shutdown_graceful` subscribes later).
            // `send_replace` always stores it, so late subscribers observe `true`.
            self.run_state.send_replace(true);
            Ok(())
        }
    }

    /// Clear the running flag and publish `false` on the run-state channel.
    fn finish_run(&self) {
        self.is_running.store(false, Ordering::SeqCst);
        // Same reasoning as in `begin_run`: store the value even without receivers.
        self.run_state.send_replace(false);
    }

    /// Request cancellation of the run loop and its session tasks.
    fn cancel(&self) {
        self.cancel.cancel();
    }

    /// Clone of the shared cancellation token.
    fn cancel_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

    /// Whether `cancel` has been called.
    fn is_cancelled(&self) -> bool {
        self.cancel.is_cancelled()
    }

    /// Whether a run is currently in progress.
    fn is_running(&self) -> bool {
        self.is_running.load(Ordering::SeqCst)
    }

    /// New receiver on the client-level incoming-message broadcast.
    fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
        self.message_tx.subscribe()
    }

    /// Clone of the client-level broadcast sender (used by the forwarder task).
    fn message_channel(&self) -> broadcast::Sender<SharedMessage> {
        self.message_tx.clone()
    }

    /// Shared handle to the message dispatcher.
    fn dispatcher(&self) -> Arc<MessageDispatcher<crate::connection::DefaultWsStream>> {
        self.dispatcher.clone()
    }

    /// Sender handle for the installed send queue, or `None` if none is installed.
    async fn sender(&self) -> Option<Sender> {
        let guard = self.send_tx.read().await;
        guard.as_ref().map(|tx| Sender { tx: tx.clone() })
    }

    /// Enqueue without waiting.
    ///
    /// # Errors
    ///
    /// `NotConnected` if no queue is installed, `ChannelFull` / `ChannelClosed` from the queue.
    async fn send(&self, message: Message) -> Result<(), SendError> {
        let guard = self.send_tx.read().await;
        guard.as_ref().map_or(Err(SendError::NotConnected), |tx| {
            match tx.try_send(message) {
                Ok(()) => Ok(()),
                Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Err(SendError::ChannelFull),
                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                    Err(SendError::ChannelClosed)
                }
            }
        })
    }

    /// Enqueue, waiting for capacity. The lock is released before waiting.
    ///
    /// # Errors
    ///
    /// `NotConnected` if no queue is installed, `ChannelClosed` if the queue's receiver is gone.
    async fn send_async(&self, message: Message) -> Result<(), SendError> {
        let tx = self
            .send_tx
            .read()
            .await
            .as_ref()
            .ok_or(SendError::NotConnected)?
            .clone();
        tx.send(message).await.map_err(|_| SendError::ChannelClosed)
    }

    /// Enqueue, waiting at most `timeout` for capacity. The lock is released before waiting.
    ///
    /// # Errors
    ///
    /// `NotConnected`, `ChannelClosed`, or `Timeout(timeout)`.
    async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
        let tx = self
            .send_tx
            .read()
            .await
            .as_ref()
            .ok_or(SendError::NotConnected)?
            .clone();
        match tokio::time::timeout(timeout, tx.send(message)).await {
            Ok(Ok(())) => Ok(()),
            Ok(Err(_)) => Err(SendError::ChannelClosed),
            Err(_) => Err(SendError::Timeout(timeout)),
        }
    }

    /// Install the send queue for a new session.
    async fn set_send_channel(&self, tx: mpsc::Sender<Message>) {
        let mut guard = self.send_tx.write().await;
        *guard = Some(tx);
    }

    /// Remove the installed send queue (subsequent sends report `NotConnected`).
    async fn clear_send_channel(&self) {
        let mut guard = self.send_tx.write().await;
        *guard = None;
    }

    /// Subscribe to run-state changes (`true` while running).
    fn run_state_receiver(&self) -> watch::Receiver<bool> {
        self.run_state.subscribe()
    }
}

/// WebSocket client with automatic reconnection (thin wrapper over `ConnectionSupervisor`)
///
/// Connections are obtained through the supervisor. After each successful connect the
/// client runs its handshaker, then attaches the socket to the message dispatcher owned by
/// its internal runtime.
pub struct WebSocketClient {
    /// Target URI (e.g. `wss://example.com/ws`)
    uri: String,
    /// Client settings (timeouts, queue capacities, handshake retry delay)
    config: ClientConfig,
    /// Application-level handshake performed after every successful connect
    handshaker: BoxHandshaker,
    /// Registered extensions: notified on connect/disconnect/shutdown and fed each message
    extension_host: Arc<ExtensionHost>,
    /// Performs connect/reconnect attempts and publishes connection events
    supervisor: ConnectionSupervisor<DefaultConnector>,
    /// Shared run/cancel state, outgoing send queue and message dispatcher
    runtime: Arc<ClientRuntime>,
}

impl WebSocketClient {
    /// Create a new WebSocket client builder
    pub fn builder(uri: impl Into<String>) -> WebSocketClientBuilder {
        WebSocketClientBuilder::new(uri)
    }

    /// Create a client with default configuration
    pub fn new(uri: impl Into<String>) -> Self {
        Self::builder(uri).build()
    }

    /// Subscribe to incoming messages
    ///
    /// Returns a receiver for shared messages wrapped in `Arc<Message>` for zero-copy broadcasting.
    /// To work with the message:
    /// - **Read-only access**: `msg.as_ref()` or dereference `&*msg`
    /// - **Need owned copy**: `Arc::try_unwrap(msg).unwrap_or_else(|arc| (*arc).clone())`
    /// - **Clone specific data**: `msg.clone()` clones the Arc (cheap), `(*msg).clone()` clones the Message
    #[must_use]
    pub fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
        self.runtime.subscribe()
    }

    /// Get target URI
    #[must_use]
    pub fn uri(&self) -> &str {
        &self.uri
    }

    /// Subscribe to connection events
    #[must_use]
    pub fn subscribe_events(&self) -> broadcast::Receiver<ConnectionEvent> {
        self.supervisor.subscribe()
    }

    /// Get a sender handle for sending messages
    pub async fn sender(&self) -> Option<Sender> {
        self.runtime.sender().await
    }

    /// Send a message (convenience method)
    pub async fn send(&self, message: Message) -> Result<(), SendError> {
        self.runtime.send(message).await
    }
    /// Send a message (async, blocking on capacity)
    pub async fn send_async(&self, message: Message) -> Result<(), SendError> {
        self.runtime.send_async(message).await
    }

    /// Send a message with timeout (async)
    pub async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
        self.runtime.send_timeout(message, timeout).await
    }

    /// Get current connection state snapshot
    pub async fn state(&self) -> ConnectionSnapshot {
        self.supervisor.snapshot().await
    }

    /// Check if currently connected
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.supervisor.is_connected()
    }

    /// Register an extension
    pub async fn register_extension<E: Extension + 'static>(
        &self,
        extension: E,
    ) -> Result<(), ClientError> {
        self.extension_host
            .register(extension)
            .await
            .map_err(ClientError::Extension)
    }

    /// Run the client (blocking)
    pub async fn run(&self) -> Result<(), ClientError> {
        self.runtime.begin_run()?;
        let result = self.run_loop().await;
        self.runtime.finish_run();
        result
    }

    /// Shutdown the client
    pub fn shutdown(&self) {
        // Fast shutdown: immediately signal cancellation and stop supervisor.
        // Pending messages may be dropped; use `shutdown_graceful` to wait for
        // the run loop to finish instead.
        self.runtime.cancel();
        // Also request supervisor to stop any in-flight connect attempts
        self.supervisor.shutdown();
    }

    /// Shutdown the client gracefully, waiting for the run loop to exit or timing out.
    ///
    /// This method triggers [`shutdown()`](Self::shutdown) and then waits for the run loop
    /// to finish. If the client is not running, it returns immediately.
    ///
    /// Note: This only waits for the run loop to exit (extensions receive disconnect/shutdown
    /// hooks). It does **not** guarantee that pending user messages make it to the peer;
    /// callers should coordinate with their own protocols if that guarantee is required.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::ShutdownTimeout`] if the run loop does not exit within
    /// the specified timeout duration.
    pub async fn shutdown_graceful(&self, timeout: Duration) -> Result<(), ClientError> {
        let mut run_state = self.runtime.run_state_receiver();
        self.shutdown();
        if !self.runtime.is_running() || !*run_state.borrow() {
            return Ok(());
        }

        let wait_for_shutdown = async {
            while run_state.changed().await.is_ok() {
                if !*run_state.borrow() {
                    break;
                }
            }
        };

        match tokio::time::timeout(timeout, wait_for_shutdown).await {
            Ok(()) => Ok(()),
            Err(_) => Err(ClientError::ShutdownTimeout(timeout)),
        }
    }

    async fn run_loop(&self) -> Result<(), ClientError> {
        loop {
            if self.runtime.is_cancelled() {
                tracing::info!("Shutdown requested");
                self.extension_host.shutdown().await?;
                return Ok(());
            }
            // Connect and establish a session (attach dispatcher, update extensions)
            let (stream, mut send_rx, connection_id) = match self.establish_session().await {
                Ok(t) => t,
                Err(ClientError::Supervisor(SupervisorError::Shutdown)) => {
                    tracing::info!("Supervisor shutdown requested");
                    self.extension_host.shutdown().await?;
                    return Ok(());
                }
                Err(ClientError::Handshake(_)) => {
                    // Handshake failure already handled with delay; retry
                    continue;
                }
                Err(e) => {
                    self.extension_host.shutdown().await?;
                    return Err(e);
                }
            };

            // Spawn receiver and forwarder tasks
            let (mut recv_task, forward_task) = self.spawn_receiver_and_bridge(stream);

            // Drive outgoing sends and watch receiver/cancel
            let disconnect_reason = self.drive_session(&mut send_rx, &mut recv_task).await;

            // Cleanup and notify
            self.cleanup_session(forward_task, disconnect_reason, connection_id)
                .await?;
        }
    }

    async fn connect_via_supervisor(
        &self,
    ) -> Result<crate::connection::DefaultWsStream, ClientError> {
        match self.supervisor.connect().await {
            Ok(stream) => Ok(stream),
            Err(e) => Err(ClientError::Supervisor(e)),
        }
    }

    async fn establish_session(
        &self,
    ) -> Result<
        (
            futures_util::stream::SplitStream<crate::connection::DefaultWsStream>,
            mpsc::Receiver<Message>,
            u64,
        ),
        ClientError,
    > {
        let ws_stream = self.connect_via_supervisor().await?;

        // Split stream
        let (mut sink, mut stream) = ws_stream.split();

        // Create send channel
        let (send_tx, send_rx) = mpsc::channel::<Message>(self.config.send_queue_capacity);
        self.runtime.set_send_channel(send_tx).await;

        // Perform handshake
        if let Err(e) = self.perform_handshake(&mut sink, &mut stream).await {
            tracing::error!(error = ?e, "Handshake failed");
            self.supervisor
                .mark_disconnected(DisconnectReason::Error(e.to_string()))
                .await;
            // Respect handshaker retryability semantics
            if self.handshaker.is_retryable(&e) {
                tokio::time::sleep(self.config.handshake_retry_delay).await;
                return Err(ClientError::Handshake(e));
            }
            // Graceful stop: emit Fatal event, request supervisor shutdown, and surface Shutdown to run_loop
            self.supervisor
                .fatal(crate::error::ConnectError::HandshakeFailed(e.to_string()));
            self.supervisor.shutdown();
            return Err(ClientError::Supervisor(SupervisorError::Shutdown));
        }

        // Attach dispatcher for sending path
        self.runtime.dispatcher().attach(sink).await;

        // Update extension context
        let connection_id = self.supervisor.connection_id();
        let snapshot = self.supervisor.snapshot().await;
        self.extension_host
            .update_context(connection_id, snapshot.reconnect_count)
            .await;
        let _ = self.extension_host.notify_connect().await;
        tracing::info!(connection_id = connection_id, "Connected");

        Ok((stream, send_rx, connection_id))
    }

    fn spawn_receiver_and_bridge(
        &self,
        stream: futures_util::stream::SplitStream<crate::connection::DefaultWsStream>,
    ) -> (
        tokio::task::JoinHandle<Result<(), crate::error::ReceiveError>>,
        tokio::task::JoinHandle<()>,
    ) {
        // Bridge dispatcher broadcasts to client-level message_tx
        let dispatcher = self.runtime.dispatcher();
        let ext_host = self.extension_host.clone();
        let mut disp_rx = dispatcher.subscribe();
        let client_broadcast = self.runtime.message_channel();
        let cancel_token = self.runtime.cancel_token();
        let forward_task = tokio::spawn(async move {
            loop {
                tokio::select! {
                    () = cancel_token.cancelled() => break,
                    msg = disp_rx.recv() => {
                        match msg {
                            Ok(m) => { let _ = client_broadcast.send(m); }
                            Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { /* drop lagged */ }
                            Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                        }
                    }
                }
            }
        });
        // Activity handle for last-activity updates
        let activity = self.supervisor.activity_handle();
        let recv_task = tokio::spawn(async move {
            dispatcher
                .receive_loop_with_processor(
                    stream,
                    move || {
                        let activity = activity.clone();
                        async move { activity.update().await }
                    },
                    move |msg| {
                        let ext_host = ext_host.clone();
                        async move { ext_host.process_message(&msg).await }
                    },
                )
                .await
        });

        (recv_task, forward_task)
    }

    async fn drive_session(
        &self,
        send_rx: &mut mpsc::Receiver<Message>,
        recv_task: &mut tokio::task::JoinHandle<Result<(), crate::error::ReceiveError>>,
    ) -> Option<DisconnectReason> {
        let cancel_token = self.runtime.cancel_token();
        let dispatcher = self.runtime.dispatcher();
        loop {
            tokio::select! {
                () = cancel_token.cancelled() => {
                    recv_task.abort();
                    return Some(DisconnectReason::Shutdown);
                }
                res = &mut *recv_task => {
                    return match res {
                        Ok(Ok(())) => Some(DisconnectReason::Normal),
                        Ok(Err(e)) => Some(match e {
                            crate::error::ReceiveError::Timeout(_) => DisconnectReason::Timeout,
                            crate::error::ReceiveError::StreamClosed => DisconnectReason::Normal,
                            crate::error::ReceiveError::WebSocket(err) => DisconnectReason::Error(err),
                        }),
                        Err(_) => Some(DisconnectReason::Error("receiver task aborted".to_string())),
                    }
                }
                msg = send_rx.recv() => {
                    if let Some(message) = msg {
                        if let Err(e) = dispatcher.send(message).await {
                            return Some(DisconnectReason::Error(format!("send error: {e:?}")));
                        }
                    } else {
                        return Some(DisconnectReason::Error("send channel closed".to_string()));
                    }
                }
            }
        }
    }

    async fn cleanup_session(
        &self,
        forward_task: tokio::task::JoinHandle<()>,
        disconnect_reason: Option<DisconnectReason>,
        connection_id: u64,
    ) -> Result<(), ClientError> {
        self.runtime.clear_send_channel().await;
        forward_task.abort();
        self.runtime.dispatcher().detach().await;

        let reason = disconnect_reason.unwrap_or(DisconnectReason::Normal);
        self.supervisor.mark_disconnected(reason.clone()).await;
        let _ = self.extension_host.notify_disconnect().await;

        tracing::info!(
            connection_id = connection_id,
            reason = ?Some(reason.clone()),
            "Disconnected"
        );

        if self.runtime.is_cancelled() {
            tracing::info!("Shutdown requested after disconnect");
            self.extension_host.shutdown().await?;
            return Ok(());
        }

        Ok(())
    }

    async fn perform_handshake<S, St>(
        &self,
        sink: &mut S,
        stream: &mut St,
    ) -> Result<(), HandshakeError>
    where
        S: futures_util::Sink<Message, Error = tungstenite::Error> + Unpin + Send,
        St: futures_util::Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send,
    {
        use crate::context::ConnectionContext;

        let snapshot = self.supervisor.snapshot().await;
        let context =
            ConnectionContext::new(snapshot.id).with_reconnect_count(snapshot.reconnect_count);

        self.handshaker
            .handshake_with_timeout(sink, stream, &context)
            .await
    }

    // old message_loop removed; receive path is now fully handled by MessageDispatcher
}

/// Builder for `WebSocketClient`
///
/// `WebSocketClientBuilder::new` starts from `ClientConfig::default()`, exponential-backoff
/// retries and a no-op handshaker.
pub struct WebSocketClientBuilder {
    /// Target URI for the client being built
    uri: String,
    /// Client configuration
    config: ClientConfig,
    /// Reconnection policy
    retry_strategy: Box<dyn RetryStrategy>,
    /// Application-level handshake performed after each connect
    handshaker: BoxHandshaker,
}

impl WebSocketClientBuilder {
    /// Create a builder targeting `uri` with default settings.
    ///
    /// Defaults: `ClientConfig::default()`, an `ExponentialBackoff::default()`
    /// retry strategy and a `NoOpHandshaker`.
    pub fn new(uri: impl Into<String>) -> Self {
        Self {
            uri: uri.into(),
            config: ClientConfig::default(),
            retry_strategy: Box::new(ExponentialBackoff::default()),
            handshaker: Box::new(NoOpHandshaker),
        }
    }

    /// Replace the entire client configuration.
    ///
    /// This overwrites anything previously written into the configuration by
    /// other setters (for example [`Self::receive_timeout`]), so call it first.
    #[must_use]
    #[allow(clippy::missing_const_for_fn)] // ClientConfig is not const-compatible
    pub fn config(mut self, config: ClientConfig) -> Self {
        self.config = config;
        self
    }

    /// Set the receive timeout stored in the client configuration.
    #[must_use]
    pub const fn receive_timeout(mut self, timeout: Duration) -> Self {
        self.config.receive_timeout = timeout;
        self
    }

    /// Use `strategy` as the reconnect policy.
    #[must_use]
    pub fn retry_strategy<S: RetryStrategy + 'static>(mut self, strategy: S) -> Self {
        self.retry_strategy = Box::new(strategy);
        self
    }

    /// Use `handshaker` for the handshake performed on each connection.
    #[must_use]
    pub fn handshaker<H: Handshaker + 'static>(mut self, handshaker: H) -> Self {
        self.handshaker = Box::new(handshaker);
        self
    }

    /// Disable retry (single connection attempt) by installing `NoRetry`.
    #[must_use]
    pub fn no_retry(mut self) -> Self {
        self.retry_strategy = Box::new(NoRetry);
        self
    }

    /// Use exponential backoff with custom parameters.
    ///
    /// `initial` and `max` are passed to `ExponentialBackoff::new`, and
    /// `multiplier` becomes its growth factor.
    ///
    /// A fixed jitter of `0.1` is always applied. To choose a different jitter,
    /// build an `ExponentialBackoff` yourself and pass it to
    /// [`Self::retry_strategy`].
    #[must_use]
    pub fn exponential_backoff(
        mut self,
        initial: Duration,
        max: Duration,
        multiplier: f64,
    ) -> Self {
        self.retry_strategy = Box::new(
            ExponentialBackoff::new(initial, max)
                .with_factor(multiplier)
                .with_jitter(0.1),
        );
        self
    }

    /// Consume the builder and assemble the `WebSocketClient`.
    ///
    /// This wires up the following pieces:
    /// - a shared `ClientRuntime` built from the configuration;
    /// - a `DefaultConnector` using `with_nodelay(disable_nagle)` and the
    ///   optional `ws_config`;
    /// - a `ConnectionSupervisor` carrying the retry strategy,
    ///   `exit_on_first_failure` and `connect_timeout`;
    /// - a fresh `ExtensionHost`.
    #[must_use]
    pub fn build(self) -> WebSocketClient {
        let runtime = Arc::new(ClientRuntime::new(&self.config));

        let connector = DefaultConnector::new().with_nodelay(self.config.disable_nagle);
        // `ws_config` is an `Option`, while `with_ws_config` takes a concrete value:
        // only override the connector's WebSocket settings when one was provided.
        let connector = match self.config.ws_config {
            Some(ws_cfg) => connector.with_ws_config(ws_cfg),
            None => connector,
        };

        // Supervisor settings come from the client config plus the retry strategy.
        let mut sup_cfg = crate::connection::SupervisorConfig::new();
        sup_cfg.retry_strategy = self.retry_strategy;
        sup_cfg.exit_on_first_failure = self.config.exit_on_first_failure;
        sup_cfg.connect_timeout = self.config.connect_timeout;

        let supervisor =
            ConnectionSupervisor::with_connector(self.uri.clone(), connector).with_config(sup_cfg);

        WebSocketClient {
            uri: self.uri,
            config: self.config,
            handshaker: self.handshaker,
            extension_host: Arc::new(ExtensionHost::new()),
            supervisor,
            runtime,
        }
    }
}

/// Extension trait for running client in background
///
/// Implementors are driven to completion on a spawned task rather than by the
/// caller awaiting them directly.
pub trait WebSocketClientExt {
    /// Spawn the client as a background task
    ///
    /// Takes `self` by `Arc`, so the caller can keep its own clone of the
    /// client while the task runs. The returned handle resolves to the
    /// `Result` produced by the background task.
    fn spawn(self: Arc<Self>) -> tokio::task::JoinHandle<Result<(), ClientError>>;
}

impl WebSocketClientExt for WebSocketClient {
    /// Move the shared client into a future that drives `run()` and hand that
    /// future to `tokio::spawn`; the handle yields `run()`'s result.
    fn spawn(self: Arc<Self>) -> tokio::task::JoinHandle<Result<(), ClientError>> {
        let client = self;
        let task = async move { client.run().await };
        tokio::spawn(task)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_config_defaults() {
        let config = ClientConfig::default();
        assert_eq!(config.receive_timeout, Duration::from_secs(20));
        assert!(!config.exit_on_first_failure);
        assert!(!config.disable_nagle);
    }

    #[test]
    fn test_client_config_presets() {
        let fast = ClientConfig::fast_reconnect();
        assert_eq!(fast.receive_timeout, Duration::from_secs(10));
        assert!(fast.disable_nagle);

        let stable = ClientConfig::stable_connection();
        assert_eq!(stable.receive_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder() {
        let client = WebSocketClient::builder("ws://localhost:8080")
            .receive_timeout(Duration::from_secs(30))
            .no_retry()
            .build();

        assert_eq!(client.config.receive_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_sender_backpressure_full() {
        use tokio::sync::mpsc;

        // Capacity-1 channel. The receiver is bound to `_rx` (not `_`) so it stays
        // alive and the channel remains open for the whole test; it is never read,
        // so it does not need to be `mut`.
        let (tx, _rx) = mpsc::channel::<Message>(1);
        let sender = Sender { tx };

        // The first message occupies the only slot...
        assert!(sender.send(Message::Text("a".into())).is_ok());
        // ...so the next one must be rejected as backpressure.
        let res = sender.send(Message::Text("b".into()));
        assert!(matches!(res, Err(crate::error::SendError::ChannelFull)));
    }

    #[tokio::test]
    async fn test_client_shutdown_exits_quickly() {
        // Build a normal client but cancel before running to avoid any network attempt
        let client = WebSocketClient::builder("wss://example.test/ws")
            .receive_timeout(std::time::Duration::from_millis(100))
            .no_retry()
            .build();
        let client = std::sync::Arc::new(client);

        // Cancel before run starts
        client.shutdown();

        let h = {
            let c = client.clone();
            tokio::spawn(async move { c.run().await })
        };

        // Distinguish "timed out" from "task panicked" instead of two bare unwraps.
        let joined = tokio::time::timeout(std::time::Duration::from_secs(1), h)
            .await
            .expect("run() did not exit in time");
        let run_res = joined.expect("run() task panicked");
        assert!(run_res.is_ok(), "run() returned error: {run_res:?}");
    }
}