Skip to main content

hpx_yawc/native/
mod.rs

1//! Native WebSocket implementation for Tokio runtime.
2//!
3//! # Architecture Overview
4//!
5//! This module provides the core WebSocket implementation with two complementary APIs:
6//!
7//! - **[`WebSocket`]**: High-level API with automatic fragment assembly, compression, and protocol handling
8//! - **[`Streaming`]**: Low-level API for manual fragment control and streaming compression
9//!
10//! ## WebSocket Layer Responsibilities
11//!
12//! The [`WebSocket`] type provides automatic handling of:
13//!
14//! - **Fragment assembly**: Automatically reassembles fragmented messages into complete payloads
15//! - **Message compression**: Applies permessage-deflate (RFC 7692) decompression to complete messages
16//! - **UTF-8 validation**: Validates text frames contain valid UTF-8 (when enabled)
17//! - **Protocol control**: Handles Ping/Pong and Close frames automatically
18//! - **Automatic fragmentation**: Optionally splits large outgoing messages into fragments
19//!
20//! ## Streaming Layer Responsibilities
21//!
22//! The [`Streaming`] type provides direct control over:
23//!
24//! - **Manual fragments**: Send and receive individual frame fragments without assembly
25//! - **Streaming compression**: Compress data incrementally with partial flushes
26//! - **Memory efficiency**: Process large messages without buffering entire payloads
27//! - **Frame-level access**: Direct access to WebSocket frames for custom protocols
28//!
29//! ## Complete Architecture Stack
30//!
31//! ```text
32//! ┌────────────────────────────────────────────────┐
33//! │  Application Layer                             │
34//! └──────────────────┬─────────────────────────────┘
35//!                    │
36//! ┌──────────────────▼─────────────────────────────┐
37//! │  WebSocket Layer (automatic mode)              │
38//! │  • Automatic fragment assembly                 │
39//! │  • Message-level compression                   │
40//! │  • UTF-8 validation for text frames            │
41//! │  • Automatic Ping/Pong/Close handling          │
42//! │  • Automatic fragmentation (optional)          │
43//! └──────────────────┬─────────────────────────────┘
44//!                    │
45//!          ┌─────────▼─────────┐
46//!          │  .into_streaming()│
47//!          └─────────┬─────────┘
48//!                    │
49//! ┌──────────────────▼─────────────────────────────┐
50//! │  Streaming Layer (manual mode)                 │
51//! │  • Manual fragment control                     │
52//! │  • Streaming compression with partial flushes  │
53//! │  • Direct frame access                         │
54//! │  • Memory-efficient large message handling     │
55//! └──────────────────┬─────────────────────────────┘
56//!                    │
57//! ┌──────────────────▼─────────────────────────────┐
58//! │  ReadHalf / WriteHalf                          │
59//! │  • Connection state management                 │
60//! │  • Buffer coordination                         │
61//! └──────────────────┬─────────────────────────────┘
62//!                    │
63//! ┌──────────────────▼─────────────────────────────┐
64//! │  Codec Layer                                   │
65//! │  • Frame encoding/decoding                     │
66//! │  • Masking/unmasking                           │
67//! │  • Header parsing (FIN, RSV, OpCode)           │
68//! └──────────────────┬─────────────────────────────┘
69//!                    │
70//!              Network (TCP/TLS)
71//! ```
72//!
73//! ## Example: WebSocket Mode (Automatic)
74//!
75//! When receiving a compressed fragmented message with [`WebSocket`]:
76//!
77//! 1. **Codec**: Decodes 3 individual frames from network bytes
78//!    - `Frame(Text, RSV1=1, FIN=0, payload1)`
79//!    - `Frame(Continuation, RSV1=0, FIN=0, payload2)`
80//!    - `Frame(Continuation, RSV1=0, FIN=1, payload3)`
81//!
82//! 2. **WebSocket**: Assembles and decompresses automatically
83//!    - Concatenates: `payload1 + payload2 + payload3`
84//!    - Decompresses the complete message
85//!    - Validates UTF-8 for text frames
86//!    - Returns: `Frame(Text, payload="decompressed data")`
87//!
88//! ## Example: Streaming Mode (Manual)
89//!
90//! When receiving the same message with [`Streaming`]:
91//!
92//! 1. **Codec**: Decodes individual frames (same as above)
93//!
94//! 2. **Streaming**: Returns each frame individually
95//!    - Application receives: `Frame(Text, RSV1=1, FIN=0, payload1)`
96//!    - Application receives: `Frame(Continuation, RSV1=0, FIN=0, payload2)`
97//!    - Application receives: `Frame(Continuation, RSV1=0, FIN=1, payload3)`
98//!    - Application handles assembly and decompression manually
99//!
100//! This allows applications to stream-decompress data as fragments arrive,
101//! enabling memory-efficient processing of large messages.
102//!
103//! ## Use Case Selection
104//!
105//! **Use [`WebSocket`] when:**
106//! - You want automatic protocol handling
107//! - Messages fit comfortably in memory
108//! - You need simple, ergonomic APIs
109//!
110//! **Use [`Streaming`] when:**
111//! - Processing messages larger than available memory (e.g., file transfers)
112//! - Implementing custom fragmentation strategies
113//! - Building low-latency systems requiring frame-level control
114//! - Streaming compression is needed for real-time data
115
116mod builder;
117mod options;
118mod split;
119pub mod streaming;
120mod upgrade;
121
122use std::{
123    borrow::BorrowMut,
124    collections::VecDeque,
125    future::poll_fn,
126    io,
127    net::SocketAddr,
128    pin::{Pin, pin},
129    str::FromStr,
130    sync::Arc,
131    task::{Context, Poll, ready},
132    time::{Duration, Instant},
133};
134
135pub use builder::{HttpRequest, HttpRequestBuilder, WebSocketBuilder};
136use bytes::Bytes;
137use codec::Codec;
138use compression::{Compressor, Decompressor, WebSocketExtensions};
139pub use frame::{Frame, OpCode};
140use futures::{SinkExt, task::AtomicWaker};
141use http_body_util::Empty;
142use hyper::{Request, Response, StatusCode, body::Incoming, header, upgrade::Upgraded};
143use hyper_util::rt::TokioIo;
144pub use options::{CompressionLevel, DeflateOptions, Fragmentation, Options};
145pub use split::{ReadHalf, WriteHalf};
146use tokio::{
147    io::{AsyncRead, AsyncWrite},
148    net::TcpStream,
149};
150use tokio_rustls::{
151    TlsConnector,
152    rustls::{
153        self,
154        pki_types::{ServerName, TrustAnchor},
155    },
156};
157use tokio_util::codec::Framed;
158pub use upgrade::UpgradeFut;
159use url::Url;
160
161// Re-exports
162pub use crate::stream::MaybeTlsStream;
163use crate::{Result, WebSocketError, codec, compression, frame, streaming::Streaming};
164
/// Type alias for client WebSocket connections established via [`WebSocket::connect`].
///
/// This is the WebSocket type produced by the `connect` builder. It runs over a TCP
/// stream that is either plain (`ws://`) or wrapped in TLS (`wss://`), as represented by
/// [`MaybeTlsStream`].
pub type TcpWebSocket = WebSocket<MaybeTlsStream<TcpStream>>;
170
/// Type alias for server-side WebSocket connections obtained from HTTP upgrades.
///
/// This is the WebSocket type returned by [`WebSocket::upgrade`] and [`UpgradeFut`].
/// It runs over [`HttpStream`], which wraps hyper's upgraded HTTP connection.
pub type HttpWebSocket = WebSocket<HttpStream>;
176
177#[cfg(feature = "axum")]
178pub use upgrade::IncomingUpgrade;
179
/// Upper bound, in bytes, on the payload of a single incoming frame: 1 MiB (2^20 bytes).
///
/// Frames whose payload is larger than this are rejected. This keeps memory usage bounded
/// and stops oversized messages from degrading performance.
pub const MAX_PAYLOAD_READ: usize = 1 << 20;
185
/// Upper bound, in bytes, on the read buffer: 2 MiB (2 × 2^20 bytes).
///
/// If buffered incoming data grows beyond this size, the connection is closed. This
/// prevents unbounded memory growth while fragmented messages are being reassembled.
pub const MAX_READ_BUFFER: usize = 2 << 20;
191
/// Type alias for HTTP responses used during the WebSocket upgrade.
///
/// This is the response sent back to the client during the WebSocket handshake. Its body
/// is always empty, because after the handshake the connection switches from HTTP to the
/// WebSocket protocol and no HTTP body is exchanged.
///
/// It is the first element of [`UpgradeResult`] and carries the headers needed to switch
/// protocols.
pub type HttpResponse = Response<Empty<Bytes>>;
201
/// The result type returned by server-side WebSocket upgrade operations.
///
/// On success it holds:
/// - an [`HttpResponse`] with the WebSocket upgrade headers, to be sent to the client;
/// - an [`UpgradeFut`] that resolves to the WebSocket connection once the protocol
///   switch has completed.
///
/// Both parts must be handled for the upgrade to succeed:
/// 1. Send the HTTP response to the client.
/// 2. Await the future to obtain the WebSocket connection.
pub type UpgradeResult = Result<(HttpResponse, UpgradeFut)>;
212
/// Parameters negotiated with the client or the server.
///
/// Used to build the per-connection permessage-deflate state through
/// [`Negotiation::compressor`] and [`Negotiation::decompressor`].
#[derive(Debug, Default, Clone)]
pub(crate) struct Negotiation {
    /// Agreed permessage-deflate parameters; `None` when compression was not negotiated.
    pub(crate) extensions: Option<WebSocketExtensions>,
    /// Compression level for outgoing messages. When `None`, no compressor is built even
    /// if `extensions` is set.
    pub(crate) compression_level: Option<CompressionLevel>,
    /// Maximum payload size accepted for a single incoming frame
    /// (presumably seeded from [`MAX_PAYLOAD_READ`] by default — TODO confirm in `Options`).
    pub(crate) max_payload_read: usize,
    /// Limit on buffered bytes while reassembling incoming fragmented messages
    /// (presumably seeded from [`MAX_READ_BUFFER`] by default — TODO confirm in `Options`).
    pub(crate) max_read_buffer: usize,
    /// Presumably controls UTF-8 validation of incoming text frames — verify against readers.
    pub(crate) utf8: bool,
    /// Automatic fragmentation settings for outgoing messages; `None` when disabled.
    pub(crate) fragmentation: Option<options::Fragmentation>,
    /// NOTE(review): semantics are not visible in this file; this looks like a write-buffer
    /// size threshold for backpressure — confirm against the write path.
    pub(crate) max_backpressure_write_boundary: Option<usize>,
}
224
225impl Negotiation {
226    pub(crate) fn decompressor(&self, role: Role) -> Option<Decompressor> {
227        let config = self.extensions.as_ref()?;
228
229        tracing::debug!(
230            "Established decompressor for {role} with settings \
231            client_no_context_takeover={} server_no_context_takeover={} \
232            server_max_window_bits={:?} client_max_window_bits={:?}",
233            config.client_no_context_takeover,
234            config.client_no_context_takeover,
235            config.server_max_window_bits,
236            config.client_max_window_bits
237        );
238
239        // configure the decompressor using the assigned role and preferred flags.
240        Some(if role == Role::Server {
241            if config.client_no_context_takeover {
242                Decompressor::no_context_takeover()
243            } else {
244                #[cfg(feature = "zlib")]
245                if let Some(Some(window_bits)) = config.client_max_window_bits {
246                    Decompressor::new_with_window_bits(window_bits.max(9))
247                } else {
248                    Decompressor::new()
249                }
250                #[cfg(not(feature = "zlib"))]
251                Decompressor::new()
252            }
253        } else {
254            // client
255            if config.server_no_context_takeover {
256                Decompressor::no_context_takeover()
257            } else {
258                #[cfg(feature = "zlib")]
259                if let Some(Some(window_bits)) = config.server_max_window_bits {
260                    Decompressor::new_with_window_bits(window_bits)
261                } else {
262                    Decompressor::new()
263                }
264                #[cfg(not(feature = "zlib"))]
265                Decompressor::new()
266            }
267        })
268    }
269
270    pub(crate) fn compressor(&self, role: Role) -> Option<Compressor> {
271        let config = self.extensions.as_ref()?;
272
273        tracing::debug!(
274            "Established compressor for {role} with settings \
275            client_no_context_takeover={} server_no_context_takeover={} \
276            server_max_window_bits={:?} client_max_window_bits={:?}",
277            config.client_no_context_takeover,
278            config.client_no_context_takeover,
279            config.server_max_window_bits,
280            config.client_max_window_bits
281        );
282
283        let level = self.compression_level?;
284
285        // configure the compressor using the assigned role and preferred flags.
286        Some(if role == Role::Client {
287            if config.client_no_context_takeover {
288                Compressor::no_context_takeover(level)
289            } else {
290                #[cfg(feature = "zlib")]
291                if let Some(Some(window_bits)) = config.client_max_window_bits {
292                    Compressor::new_with_window_bits(level, window_bits)
293                } else {
294                    Compressor::new(level)
295                }
296                #[cfg(not(feature = "zlib"))]
297                Compressor::new(level)
298            }
299        } else {
300            // server
301            if config.server_no_context_takeover {
302                Compressor::no_context_takeover(level)
303            } else {
304                #[cfg(feature = "zlib")]
305                if let Some(Some(window_bits)) = config.server_max_window_bits {
306                    Compressor::new_with_window_bits(level, window_bits)
307                } else {
308                    Compressor::new(level)
309                }
310                #[cfg(not(feature = "zlib"))]
311                Compressor::new(level)
312            }
313        })
314    }
315}
316
/// The role the WebSocket stream is taking.
///
/// In the server role, outgoing frames are not masked. In the client role, outgoing
/// frames are masked.
///
/// The role also decides which permessage-deflate parameters (`client_*` or `server_*`)
/// apply to the local compressor and decompressor.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Role {
    /// Accepting side of the connection; sends unmasked frames.
    Server,
    /// Initiating side of the connection; sends masked frames.
    Client,
}
326
327impl std::fmt::Display for Role {
328    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329        match self {
330            Self::Server => write!(f, "server"),
331            Self::Client => write!(f, "client"),
332        }
333    }
334}
335
/// Which side of the stream a waker belongs to.
///
/// Read-side and write-side polls register their wakers separately, so that a
/// wake-up can reach both kinds of pending task.
#[derive(Copy, Clone)]
enum ContextKind {
    /// The waker comes from a read poll.
    Read,
    /// The waker comes from a write poll.
    Write,
}
347
/// Manages separate wakers for read and write operations on the WebSocket stream.
///
/// Shared through an `Arc` and used as a waker itself (via its `ArcWake` impl): waking the
/// proxy wakes both the registered read waker and the registered write waker.
#[derive(Default)]
struct WakeProxy {
    /// Waker for read operations, registered with `set_waker(ContextKind::Read, ..)`.
    read_waker: AtomicWaker,
    /// Waker for write operations, registered with `set_waker(ContextKind::Write, ..)`.
    write_waker: AtomicWaker,
}
356
357impl futures::task::ArcWake for WakeProxy {
358    fn wake_by_ref(this: &Arc<Self>) {
359        this.read_waker.wake();
360        this.write_waker.wake();
361    }
362}
363
364impl WakeProxy {
365    #[inline]
366    fn set_waker(&self, kind: ContextKind, waker: &futures::task::Waker) {
367        match kind {
368            ContextKind::Read => {
369                self.read_waker.register(waker);
370            }
371            ContextKind::Write => {
372                self.write_waker.register(waker);
373            }
374        }
375    }
376
377    #[inline(always)]
378    fn with_context<F, R>(self: &Arc<Self>, f: F) -> R
379    where
380        F: FnOnce(&mut Context<'_>) -> R,
381    {
382        let waker = futures::task::waker_ref(self);
383        let mut cx = Context::from_waker(&waker);
384        f(&mut cx)
385    }
386}
387
/// Transport for server-side WebSocket connections created from an HTTP upgrade.
///
/// Currently the only variant wraps hyper's upgraded connection. [`AsyncRead`] and
/// [`AsyncWrite`] are delegated to that inner stream.
pub enum HttpStream {
    /// Connection upgraded by hyper, adapted to Tokio's I/O traits via [`TokioIo`].
    Hyper(TokioIo<Upgraded>),
}
393
394impl From<TokioIo<Upgraded>> for HttpStream {
395    fn from(value: TokioIo<Upgraded>) -> Self {
396        Self::Hyper(value)
397    }
398}
399
400impl AsyncRead for HttpStream {
401    fn poll_read(
402        self: Pin<&mut Self>,
403        cx: &mut Context<'_>,
404        buf: &mut tokio::io::ReadBuf<'_>,
405    ) -> Poll<io::Result<()>> {
406        match self.get_mut() {
407            Self::Hyper(stream) => pin!(stream).poll_read(cx, buf),
408        }
409    }
410}
411
412impl AsyncWrite for HttpStream {
413    fn poll_write(
414        self: Pin<&mut Self>,
415        cx: &mut Context<'_>,
416        buf: &[u8],
417    ) -> Poll<std::result::Result<usize, io::Error>> {
418        match self.get_mut() {
419            Self::Hyper(stream) => pin!(stream).poll_write(cx, buf),
420        }
421    }
422
423    fn poll_flush(
424        self: Pin<&mut Self>,
425        cx: &mut Context<'_>,
426    ) -> Poll<std::result::Result<(), io::Error>> {
427        match self.get_mut() {
428            Self::Hyper(stream) => pin!(stream).poll_flush(cx),
429        }
430    }
431
432    fn poll_shutdown(
433        self: Pin<&mut Self>,
434        cx: &mut Context<'_>,
435    ) -> Poll<std::result::Result<(), io::Error>> {
436        match self.get_mut() {
437            Self::Hyper(stream) => pin!(stream).poll_shutdown(cx),
438        }
439    }
440}
441
442// ================== FragmentLayer ====================
443
/// Reassembly state for an incoming fragmented message.
///
/// Created when a Text/Binary frame arrives with FIN=0, and consumed once the final
/// Continuation frame (FIN=1) arrives.
pub(super) struct FragmentationState {
    /// When the first fragment was received; used for the fragment timeout.
    started: Instant,
    /// Opcode of the first fragment; copied onto the reassembled frame.
    opcode: OpCode,
    /// Compression flag of the first fragment; copied onto the reassembled frame.
    is_compressed: bool,
    /// Total payload bytes received so far across all fragments.
    bytes_read: usize,
    /// Fragment payloads in arrival order.
    parts: VecDeque<Bytes>,
}
451
/// Handles fragmentation and defragmentation of WebSocket frames.
///
/// This layer is responsible for:
/// - **Outgoing fragmentation**: breaking large frames into smaller fragments based on `fragment_size`
/// - **Incoming defragmentation**: reassembling fragmented messages received from the peer
///
/// The layer is owned by [`WebSocket`]. It keeps independent state for the write direction
/// (`outgoing_fragments`) and the read direction (`incoming_fragment`).
struct FragmentLayer {
    /// Outgoing fragments from automatic fragmentation, sent front to back.
    outgoing_fragments: VecDeque<Frame>,
    /// In-progress reassembly of an incoming fragmented message, if one has started.
    incoming_fragment: Option<FragmentationState>,
    /// Maximum fragment size for outgoing messages; `None` means frames are not split.
    fragment_size: Option<usize>,
    /// Limit on the total payload bytes of an incoming fragmented message.
    max_read_buffer: usize,
    /// Time budget for a fragmented message, measured from its first fragment and
    /// checked whenever a continuation frame arrives.
    fragment_timeout: Option<Duration>,
}
471
472impl FragmentLayer {
473    /// Creates a new FragmentLayer with the given configuration.
474    fn new(
475        fragment_size: Option<usize>,
476        max_read_buffer: usize,
477        fragment_timeout: Option<Duration>,
478    ) -> Self {
479        Self {
480            outgoing_fragments: VecDeque::new(),
481            incoming_fragment: None,
482            fragment_size,
483            max_read_buffer,
484            fragment_timeout,
485        }
486    }
487
488    /// Fragments an outgoing frame if necessary and queues the fragments.
489    ///
490    /// Panics if the user tries to manually fragment while auto-fragmentation is enabled.
491    fn fragment_outgoing(&mut self, frame: Frame) {
492        // Check for invalid manual fragmentation with auto-fragmentation enabled
493        if !frame.is_fin() && self.fragment_size.is_some() {
494            panic!(
495                "Fragment the frames yourself or use `fragment_size`, but not both. Use Streaming"
496            );
497        }
498
499        let max_fragment_size = self.fragment_size.unwrap_or(usize::MAX);
500        self.outgoing_fragments
501            .extend(frame.into_fragments(max_fragment_size));
502    }
503
504    /// Returns the next queued outgoing fragment, if any.
505    #[inline(always)]
506    fn pop_outgoing_fragment(&mut self) -> Option<Frame> {
507        self.outgoing_fragments.pop_front()
508    }
509
510    /// Returns true if there are pending outgoing fragments.
511    #[inline(always)]
512    fn has_outgoing_fragments(&self) -> bool {
513        !self.outgoing_fragments.is_empty()
514    }
515
516    /// Processes an incoming frame, handling fragmentation and reassembly.
517    ///
518    /// Returns:
519    /// - `Ok(Some(frame))` if a complete frame is ready (either non-fragmented or fully reassembled)
520    /// - `Ok(None)` if this is a fragment and we're still waiting for more
521    /// - `Err` if there's a protocol violation or timeout
522    fn assemble_incoming(&mut self, mut frame: Frame) -> Result<Option<Frame>> {
523        use bytes::BufMut;
524
525        #[cfg(test)]
526        println!(
527            "<<Fragmentation<< OpCode={:?} fin={} len={}",
528            frame.opcode(),
529            frame.is_fin(),
530            frame.payload.len()
531        );
532
533        match frame.opcode {
534            OpCode::Text | OpCode::Binary => {
535                // Check for invalid fragmentation state
536                if self.incoming_fragment.is_some() {
537                    return Err(WebSocketError::InvalidFragment);
538                }
539
540                // Handle fragmented messages
541                if !frame.fin {
542                    let fragmentation = FragmentationState {
543                        started: Instant::now(),
544                        opcode: frame.opcode,
545                        is_compressed: frame.is_compressed,
546                        bytes_read: frame.payload.len(),
547                        parts: VecDeque::from([frame.payload]),
548                    };
549                    self.incoming_fragment = Some(fragmentation);
550
551                    return Ok(None);
552                }
553
554                // Non-fragmented message - return as-is
555                Ok(Some(frame))
556            }
557            OpCode::Continuation => {
558                let mut fragment = self
559                    .incoming_fragment
560                    .take()
561                    .ok_or_else(|| WebSocketError::InvalidFragment)?;
562
563                fragment.bytes_read += frame.payload.len();
564
565                // Check buffer size
566                if fragment.bytes_read >= self.max_read_buffer {
567                    return Err(WebSocketError::FrameTooLarge);
568                }
569
570                // Check timeout
571                if let Some(timeout) = self.fragment_timeout
572                    && fragment.started.elapsed() > timeout
573                {
574                    return Err(WebSocketError::FragmentTimeout);
575                }
576
577                fragment.parts.push_back(frame.payload);
578
579                if frame.fin {
580                    // Assemble complete message
581                    frame.opcode = fragment.opcode;
582                    frame.is_compressed = fragment.is_compressed;
583                    frame.payload = fragment
584                        .parts
585                        .into_iter()
586                        .fold(
587                            bytes::BytesMut::with_capacity(fragment.bytes_read),
588                            |mut acc, b| {
589                                acc.put(b);
590                                acc
591                            },
592                        )
593                        .freeze();
594
595                    Ok(Some(frame))
596                } else {
597                    self.incoming_fragment = Some(fragment);
598                    Ok(None)
599                }
600            }
601            _ => {
602                // Control frames pass through unchanged
603                Ok(Some(frame))
604            }
605        }
606    }
607}
608
609// ================== WebSocket ====================
610
/// WebSocket stream for both clients and servers.
///
/// The [`WebSocket`] struct manages all aspects of WebSocket communication. It handles the
/// mandatory control frames (close, ping and pong) and protocol compliance checks.
/// Framing and compression details are delegated to the underlying [`ReadHalf`] and
/// [`WriteHalf`] structures.
///
/// A [`WebSocket`] instance can be created with high-level functions such as
/// [`WebSocket::connect`], or over a custom stream with [`WebSocket::handshake`].
///
/// # Automatic Protocol Handling
///
/// The WebSocket automatically handles protocol control frames:
///
/// - **Ping frames**: When a ping frame is received, a pong response is automatically queued
///   for sending. The ping frame is **still returned to the application** via `next()` or the
///   `Stream` trait, so incoming pings can be observed if needed.
///
/// - **Pong frames**: Pong frames are passed through to the application without special handling.
///
/// - **Close frames**: When a close frame is received, a close response is automatically sent
///   (unless the connection is already closing). The close frame is returned to the
///   application, and later reads fail with [`WebSocketError::ConnectionClosed`].
///
/// # Compression Behavior
///
/// When compression is enabled (via [`Options::compression`]), the WebSocket automatically
/// compresses and decompresses messages according to RFC 7692 (permessage-deflate):
///
/// - **Outgoing messages**: Only complete (FIN=1) Text or Binary frames are compressed.
///   Frames you fragment yourself (FIN=0 or Continuation frames) are **not** compressed.
///
/// - **Incoming messages**: Compressed messages are automatically decompressed after
///   fragment assembly.
///
/// # Automatic Fragmentation
///
/// When [`Options::with_max_fragment_size`] is configured, the WebSocket automatically
/// fragments outgoing messages that exceed the specified size limit:
///
/// ```no_run
/// use futures::SinkExt;
/// use yawc::{Frame, Options, WebSocket};
///
/// # async fn example() -> yawc::Result<()> {
/// let options = Options::default().with_max_fragment_size(64 * 1024); // 64 KiB per frame
///
/// let mut ws = WebSocket::connect("wss://example.com/ws".parse()?)
///     .with_options(options)
///     .await?;
///
/// // This large message will be automatically split into multiple frames
/// let large_message = vec![0u8; 200_000]; // 200 KB
/// ws.send(Frame::binary(large_message)).await?;
/// # Ok(())
/// # }
/// ```
///
/// **Important**: When compression is also enabled, the message is first compressed as a
/// single unit. The size limit then applies to the *compressed* output, which is split into
/// fragments only if it still exceeds the limit.
///
/// # Manual Fragmentation with Streaming
///
/// `WebSocket` handles fragmentation and reassembly automatically. For **manual control**
/// over individual fragments (e.g., streaming large files or custom fragmentation logic),
/// convert the `WebSocket` into a [`Streaming`] connection:
///
/// ```no_run
/// use futures::SinkExt;
/// use yawc::{Frame, WebSocket};
///
/// # async fn example() -> yawc::Result<()> {
/// let ws = WebSocket::connect("wss://example.com/ws".parse()?).await?;
///
/// // Convert to Streaming for manual fragment control
/// let mut streaming = ws.into_streaming();
///
/// // Manually send fragments
/// streaming
///     .send(Frame::text("Hello ").with_fin(false))
///     .await?;
/// streaming
///     .send(Frame::continuation("World").with_fin(false))
///     .await?;
/// streaming.send(Frame::continuation("!")).await?;
/// # Ok(())
/// # }
/// ```
///
/// **Important**: Sending manual fragments (frames with `FIN=0`) through `WebSocket` while
/// automatic fragmentation is enabled will panic. Use [`Streaming`] for manual fragmentation.
///
/// See [`examples/streaming.rs`](https://github.com/infinitefield/yawc/blob/master/examples/streaming.rs)
/// for a complete example of manual fragment control for streaming large files.
///
/// # Connecting
/// To establish a WebSocket connection as a client:
/// ```no_run
/// use futures::StreamExt;
/// use tokio::net::TcpStream;
/// use yawc::{WebSocket, frame::OpCode};
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let ws = WebSocket::connect("wss://echo.websocket.org".parse()?).await?;
///     // Use `ws` for WebSocket communication
///     Ok(())
/// }
/// ```
pub struct WebSocket<S> {
    /// Lower-level frame stream this type builds on; exposed via `into_streaming()`.
    streaming: Streaming<S>,
    /// Presumably enables UTF-8 validation of incoming text frames — confirm in the read path.
    check_utf8: bool,
    /// Handles fragmentation and defragmentation of frames
    fragment_layer: FragmentLayer,
}
727
728impl WebSocket<MaybeTlsStream<TcpStream>> {
729    /// Establishes a WebSocket connection to the specified `url`.
730    ///
731    /// This asynchronous function supports both `ws://` (non-secure) and `wss://` (secure) schemes.
732    ///
733    /// # Parameters
734    /// - `url`: The WebSocket URL to connect to.
735    ///
736    /// # Returns
737    /// A `WebSocketBuilder` that can be further configured before establishing the connection.
738    ///
739    /// # Examples
740    /// ```no_run
741    /// use yawc::WebSocket;
742    ///
743    /// #[tokio::main]
744    /// async fn main() -> yawc::Result<()> {
745    ///     let ws = WebSocket::connect("wss://echo.websocket.org".parse()?).await?;
746    ///     Ok(())
747    /// }
748    /// ```
749    pub fn connect(url: Url) -> WebSocketBuilder {
750        WebSocketBuilder::new(url)
751    }
752
753    pub(crate) async fn connect_priv(
754        url: Url,
755        tcp_address: Option<SocketAddr>,
756        connector: Option<TlsConnector>,
757        options: Options,
758        builder: HttpRequestBuilder,
759    ) -> Result<TcpWebSocket> {
760        let host = url
761            .host()
762            .ok_or_else(|| WebSocketError::InvalidHttpScheme)?
763            .to_string();
764
765        let tcp_stream = if let Some(tcp_address) = tcp_address {
766            TcpStream::connect(tcp_address).await?
767        } else {
768            let port = url
769                .port_or_known_default()
770                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "URL has no port"))?;
771            TcpStream::connect(format!("{host}:{port}")).await?
772        };
773
774        let _ = tcp_stream.set_nodelay(options.no_delay);
775
776        let stream = match url.scheme() {
777            "ws" => MaybeTlsStream::Plain(tcp_stream),
778            "wss" => {
779                let connector = connector.unwrap_or_else(tls_connector);
780                let domain = ServerName::try_from(host)
781                    .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
782
783                MaybeTlsStream::Tls(connector.connect(domain, tcp_stream).await?)
784            }
785            _ => return Err(WebSocketError::InvalidHttpScheme),
786        };
787
788        WebSocket::handshake_with_request(url, stream, options, builder).await
789    }
790}
791
792impl<S> WebSocket<S>
793where
794    S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
795{
796    /// Performs a WebSocket handshake over an existing connection.
797    ///
798    /// This is a lower-level API that allows you to perform a WebSocket handshake
799    /// on an already established I/O stream (such as a TcpStream or TLS stream).
800    /// For most use cases, prefer using [`WebSocket::connect`] which handles both
801    /// connection establishment and handshake automatically.
802    ///
803    /// # Arguments
804    ///
805    /// * `url` - The WebSocket URL (used for generating handshake headers)
806    /// * `io` - An existing I/O stream that implements AsyncRead + AsyncWrite
807    /// * `options` - WebSocket configuration options
808    ///
809    /// # Example
810    ///
811    /// ```no_run
812    /// use tokio::net::TcpStream;
813    /// use yawc::{Options, WebSocket};
814    ///
815    /// #[tokio::main]
816    /// async fn main() -> yawc::Result<()> {
817    ///     // Establish your own TCP connection
818    ///     let stream = TcpStream::connect("example.com:80").await?;
819    ///
820    ///     // Parse the WebSocket URL
821    ///     let url = "ws://example.com/socket".parse()?;
822    ///
823    ///     // Perform the WebSocket handshake over the existing stream
824    ///     let ws = WebSocket::handshake(url, stream, Options::default()).await?;
825    ///
826    ///     // Now you can use the WebSocket connection
827    ///     // ws.send(...).await?;
828    ///
829    ///     Ok(())
830    /// }
831    /// ```
832    ///
833    /// # Use Cases
834    ///
835    /// Use this function when you need to:
836    /// - Use a custom connection method (e.g., SOCKS proxy, custom DNS resolution)
837    /// - Reuse an existing stream or connection
838    /// - Implement custom connection logic before the WebSocket handshake
839    ///
840    /// For adding custom headers to the handshake request, use
841    /// [`WebSocket::handshake_with_request`] instead.
842    pub async fn handshake(url: Url, io: S, options: Options) -> Result<WebSocket<S>> {
843        Self::handshake_with_request(url, io, options, HttpRequest::builder()).await
844    }
845
846    /// Performs a WebSocket handshake with a customizable HTTP request.
847    ///
848    /// This is similar to [`WebSocket::handshake`] but allows you to customize
849    /// the HTTP upgrade request by providing your own [`HttpRequestBuilder`].
850    /// This is useful when you need to add custom headers (e.g., authentication
851    /// tokens, API keys, or other metadata) to the handshake request.
852    ///
853    /// # Arguments
854    ///
855    /// * `url` - The WebSocket URL (used for generating handshake headers)
856    /// * `io` - An existing I/O stream that implements AsyncRead + AsyncWrite
857    /// * `options` - WebSocket configuration options
858    /// * `builder` - An HTTP request builder for customizing the handshake request
859    ///
860    /// # Example
861    ///
862    /// ```no_run
863    /// use tokio::net::TcpStream;
864    /// use yawc::{HttpRequest, Options, WebSocket};
865    ///
866    /// #[tokio::main]
867    /// async fn main() -> yawc::Result<()> {
868    ///     // Establish your own TCP connection
869    ///     let stream = TcpStream::connect("example.com:80").await?;
870    ///
871    ///     // Parse the WebSocket URL
872    ///     let url = "ws://example.com/socket".parse()?;
873    ///
874    ///     // Create a custom HTTP request with authentication headers
875    ///     let request = HttpRequest::builder()
876    ///         .header("Authorization", "Bearer my-secret-token")
877    ///         .header("X-Custom-Header", "custom-value");
878    ///
879    ///     // Perform the WebSocket handshake with custom headers
880    ///     let ws =
881    ///         WebSocket::handshake_with_request(url, stream, Options::default(), request).await?;
882    ///
883    ///     // Now you can use the WebSocket connection
884    ///     // ws.send(...).await?;
885    ///
886    ///     Ok(())
887    /// }
888    /// ```
889    ///
890    /// # Use Cases
891    ///
892    /// Use this function when you need to:
893    /// - Add authentication headers to the handshake request
894    /// - Include custom metadata or API keys
895    /// - Control the exact HTTP request sent during the WebSocket upgrade
896    /// - Combine custom connection logic with custom headers
897    pub async fn handshake_with_request(
898        url: Url,
899        io: S,
900        options: Options,
901        mut builder: HttpRequestBuilder,
902    ) -> Result<WebSocket<S>> {
903        if !builder
904            .headers_ref()
905            .map(|h| h.contains_key(header::HOST))
906            .unwrap_or(false)
907        {
908            let host = url
909                .host()
910                .ok_or(WebSocketError::InvalidHttpScheme)?
911                .to_string();
912
913            let is_port_defined = url.port().is_some();
914            let port = url
915                .port_or_known_default()
916                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "URL has no port"))?;
917            let host_header = if is_port_defined {
918                format!("{host}:{port}")
919            } else {
920                host
921            };
922
923            builder = builder.header(header::HOST, host_header.as_str());
924        }
925
926        let target_url = &url[url::Position::BeforePath..];
927
928        let mut req = builder
929            .method("GET")
930            .uri(target_url)
931            .header(header::UPGRADE, "websocket")
932            .header(header::CONNECTION, "upgrade")
933            .header(header::SEC_WEBSOCKET_KEY, generate_key())
934            .header(header::SEC_WEBSOCKET_VERSION, "13")
935            .body(Empty::<Bytes>::new())
936            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
937
938        if let Some(compression) = options.compression.as_ref() {
939            let extensions = WebSocketExtensions::from(compression);
940            let header_value =
941                extensions
942                    .to_string()
943                    .parse()
944                    .map_err(|e: header::InvalidHeaderValue| {
945                        io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
946                    })?;
947            req.headers_mut()
948                .insert(header::SEC_WEBSOCKET_EXTENSIONS, header_value);
949        }
950
951        let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(io)).await?;
952
953        #[cfg(not(feature = "smol"))]
954        tokio::spawn(async move {
955            if let Err(err) = conn.with_upgrades().await {
956                tracing::error!("upgrading connection: {:?}", err);
957            }
958        });
959
960        #[cfg(feature = "smol")]
961        smol::spawn(async move {
962            if let Err(err) = conn.with_upgrades().await {
963                tracing::error!("upgrading connection: {:?}", err);
964            }
965        })
966        .detach();
967
968        let mut response = sender.send_request(req).await?;
969        let negotiated = verify(&response, options)?;
970
971        let upgraded = hyper::upgrade::on(&mut response).await?;
972        let parts = upgraded.downcast::<TokioIo<S>>().unwrap();
973
974        // Extract the original stream and any leftover read buffer
975        let stream = parts.io.into_inner();
976        let read_buf = parts.read_buf;
977
978        Ok(WebSocket::new(Role::Client, stream, read_buf, negotiated))
979    }
980}
981
982impl WebSocket<HttpStream> {
983    // ================== Server ====================
984
985    /// Upgrades an HTTP connection to a WebSocket one.
986    pub fn upgrade<B>(request: impl BorrowMut<Request<B>>) -> UpgradeResult {
987        Self::upgrade_with_options(request, Options::default())
988    }
989
990    /// Attempts to upgrade an incoming `hyper::Request` to a WebSocket connection with customizable options.
991    pub fn upgrade_with_options<B>(
992        mut request: impl BorrowMut<Request<B>>,
993        options: Options,
994    ) -> UpgradeResult {
995        let request = request.borrow_mut();
996
997        let key = request
998            .headers()
999            .get(header::SEC_WEBSOCKET_KEY)
1000            .ok_or(WebSocketError::MissingSecWebSocketKey)?;
1001
1002        if request
1003            .headers()
1004            .get(header::SEC_WEBSOCKET_VERSION)
1005            .map(|v| v.as_bytes())
1006            != Some(b"13")
1007        {
1008            return Err(WebSocketError::InvalidSecWebsocketVersion);
1009        }
1010
1011        let maybe_compression = request
1012            .headers()
1013            .get(header::SEC_WEBSOCKET_EXTENSIONS)
1014            .and_then(|h| h.to_str().ok())
1015            .map(WebSocketExtensions::from_str)
1016            .and_then(std::result::Result::ok);
1017
1018        let mut response = Response::builder()
1019            .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
1020            .header(hyper::header::CONNECTION, "upgrade")
1021            .header(hyper::header::UPGRADE, "websocket")
1022            .header(
1023                header::SEC_WEBSOCKET_ACCEPT,
1024                upgrade::sec_websocket_protocol(key.as_bytes()),
1025            )
1026            .body(Empty::new())
1027            .map_err(|e| {
1028                io::Error::new(io::ErrorKind::InvalidInput, format!("build response: {e}"))
1029            })?;
1030
1031        let extensions = if let Some(client_compression) = maybe_compression {
1032            if let Some(server_compression) = options.compression.as_ref() {
1033                let offer = server_compression.merge(&client_compression);
1034
1035                let header_value =
1036                    offer
1037                        .to_string()
1038                        .parse()
1039                        .map_err(|e: header::InvalidHeaderValue| {
1040                            io::Error::new(
1041                                io::ErrorKind::InvalidInput,
1042                                format!("extension header: {e}"),
1043                            )
1044                        })?;
1045                response
1046                    .headers_mut()
1047                    .insert(header::SEC_WEBSOCKET_EXTENSIONS, header_value);
1048
1049                Some(offer)
1050            } else {
1051                None
1052            }
1053        } else {
1054            None
1055        };
1056
1057        let max_read_buffer = options.max_read_buffer.unwrap_or(
1058            options
1059                .max_payload_read
1060                .map(|payload_read| payload_read * 2)
1061                .unwrap_or(MAX_READ_BUFFER),
1062        );
1063
1064        let stream = UpgradeFut {
1065            inner: hyper::upgrade::on(request),
1066            negotiation: Some(Negotiation {
1067                extensions,
1068                compression_level: options
1069                    .compression
1070                    .as_ref()
1071                    .map(|compression| compression.level),
1072                max_payload_read: options.max_payload_read.unwrap_or(MAX_PAYLOAD_READ),
1073                max_read_buffer,
1074                utf8: options.check_utf8,
1075                fragmentation: options.fragmentation.clone(),
1076                max_backpressure_write_boundary: options.max_backpressure_write_boundary,
1077            }),
1078        };
1079
1080        Ok((response, stream))
1081    }
1082}
1083
1084// ======== Generic WebSocket implementation =============
1085
1086impl<S> WebSocket<S>
1087where
1088    S: AsyncRead + AsyncWrite + Unpin,
1089{
1090    /// Splits the [`WebSocket`] into its low-level components for advanced usage.
1091    ///
1092    /// # Safety
1093    /// This function is unsafe because it splits ownership of shared state.
1094    pub unsafe fn split_stream(self) -> (Framed<S, Codec>, ReadHalf, WriteHalf) {
1095        // SAFETY: caller guarantees safe usage of split ownership
1096        unsafe { self.streaming.split_stream() }
1097    }
1098
1099    /// Converts this `WebSocket` into a [`Streaming`] connection for manual fragment control.
1100    ///
1101    /// Use this when you need direct control over WebSocket frame fragmentation without
1102    /// automatic reassembly or fragmentation. This is useful for:
1103    /// - Streaming large files incrementally without loading them in memory
1104    /// - Implementing custom fragmentation strategies
1105    /// - Processing fragments as they arrive for low-latency applications
1106    /// - Fine-grained control over compression boundaries
1107    ///
1108    /// # Example
1109    ///
1110    /// ```rust,no_run
1111    /// use futures::SinkExt;
1112    /// use yawc::{Frame, WebSocket};
1113    ///
1114    /// # async fn example() -> yawc::Result<()> {
1115    /// let ws = WebSocket::connect("wss://example.com/ws".parse()?).await?;
1116    /// let mut streaming = ws.into_streaming();
1117    ///
1118    /// // Send fragments manually
1119    /// streaming
1120    ///     .send(Frame::text("Part 1").with_fin(false))
1121    ///     .await?;
1122    /// streaming.send(Frame::continuation("Part 2")).await?;
1123    /// # Ok(())
1124    /// # }
1125    /// ```
1126    ///
1127    /// See [`Streaming`] documentation and
1128    /// [`examples/streaming.rs`](https://github.com/infinitefield/yawc/blob/master/examples/streaming.rs)
1129    /// for more details.
1130    pub fn into_streaming(self) -> Streaming<S> {
1131        self.streaming
1132    }
1133
1134    /// Polls for the next frame in the WebSocket stream.
1135    pub fn poll_next_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Frame>> {
1136        loop {
1137            let frame = ready!(self.streaming.poll_next_frame(cx))?;
1138            match self.on_frame(frame)? {
1139                Some(ok) => break Poll::Ready(Ok(ok)),
1140                None => continue,
1141            }
1142        }
1143    }
1144
1145    /// Asynchronously retrieves the next frame from the WebSocket stream.
1146    pub async fn next_frame(&mut self) -> Result<Frame> {
1147        poll_fn(|cx| self.poll_next_frame(cx)).await
1148    }
1149
1150    /// Creates a new WebSocket from an existing stream.
1151    ///
1152    /// The `read_buf` parameter should contain any bytes that were read from the stream
1153    /// during the HTTP upgrade but weren't consumed (leftover data after the HTTP response).
1154    pub(crate) fn new(role: Role, stream: S, read_buf: Bytes, opts: Negotiation) -> Self {
1155        Self {
1156            streaming: Streaming::new(role, stream, read_buf, &opts),
1157            check_utf8: opts.utf8,
1158            fragment_layer: FragmentLayer::new(
1159                opts.fragmentation.as_ref().and_then(|f| f.fragment_size),
1160                opts.max_read_buffer,
1161                opts.fragmentation.as_ref().and_then(|f| f.timeout),
1162            ),
1163        }
1164    }
1165
1166    fn on_frame(&mut self, frame: Frame) -> Result<Option<Frame>> {
1167        let frame = match self.fragment_layer.assemble_incoming(frame)? {
1168            Some(frame) => frame,
1169            None => return Ok(None), // Still waiting for more fragments
1170        };
1171
1172        if frame.opcode == OpCode::Text && self.check_utf8 {
1173            #[cfg(not(feature = "simd"))]
1174            if std::str::from_utf8(&frame.payload).is_err() {
1175                return Err(WebSocketError::InvalidUTF8);
1176            }
1177            #[cfg(feature = "simd")]
1178            if simdutf8::basic::from_utf8(&frame.payload).is_err() {
1179                return Err(WebSocketError::InvalidUTF8);
1180            }
1181        }
1182
1183        Ok(Some(frame))
1184    }
1185}
1186
1187impl<S> futures::Stream for WebSocket<S>
1188where
1189    S: AsyncRead + AsyncWrite + Unpin,
1190{
1191    type Item = Frame;
1192
1193    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1194        let this = self.get_mut();
1195        match ready!(this.poll_next_frame(cx)) {
1196            Ok(ok) => Poll::Ready(Some(ok)),
1197            Err(_) => Poll::Ready(None),
1198        }
1199    }
1200}
1201
1202impl<S> futures::Sink<Frame> for WebSocket<S>
1203where
1204    S: AsyncRead + AsyncWrite + Unpin,
1205{
1206    type Error = WebSocketError;
1207
1208    fn poll_ready(
1209        self: Pin<&mut Self>,
1210        cx: &mut Context<'_>,
1211    ) -> Poll<std::result::Result<(), Self::Error>> {
1212        let this = self.get_mut();
1213        this.streaming.poll_ready_unpin(cx)
1214    }
1215
1216    fn start_send(self: Pin<&mut Self>, item: Frame) -> std::result::Result<(), Self::Error> {
1217        let this = self.get_mut();
1218        this.fragment_layer.fragment_outgoing(item);
1219        Ok(())
1220    }
1221
1222    fn poll_flush(
1223        self: Pin<&mut Self>,
1224        cx: &mut Context<'_>,
1225    ) -> Poll<std::result::Result<(), Self::Error>> {
1226        let this = self.get_mut();
1227
1228        // First, send all queued fragments to WriteHalf
1229        while this.fragment_layer.has_outgoing_fragments() {
1230            // We need to call `poll_ready` before calling `start_send` since the user
1231            // might be under certain backpressure constraints
1232            ready!(this.streaming.poll_ready_unpin(cx))?;
1233            let fragment = this
1234                .fragment_layer
1235                .pop_outgoing_fragment()
1236                .expect("fragment");
1237            this.streaming.start_send_unpin(fragment)?;
1238        }
1239
1240        // Then flush WriteHalf
1241        this.streaming.poll_flush_unpin(cx)
1242    }
1243
1244    fn poll_close(
1245        self: Pin<&mut Self>,
1246        cx: &mut Context<'_>,
1247    ) -> Poll<std::result::Result<(), Self::Error>> {
1248        let this = self.get_mut();
1249        this.streaming.poll_close_unpin(cx)
1250    }
1251}
1252
1253// ================ Helper functions ====================
1254
1255fn verify(response: &Response<Incoming>, options: Options) -> Result<Negotiation> {
1256    if response.status() != StatusCode::SWITCHING_PROTOCOLS {
1257        return Err(WebSocketError::InvalidStatusCode(
1258            response.status().as_u16(),
1259        ));
1260    }
1261
1262    let compression_level = options.compression.as_ref().map(|opts| opts.level);
1263    let headers = response.headers();
1264
1265    if !headers
1266        .get(header::UPGRADE)
1267        .and_then(|h| h.to_str().ok())
1268        .map(|h| h.eq_ignore_ascii_case("websocket"))
1269        .unwrap_or(false)
1270    {
1271        return Err(WebSocketError::InvalidUpgradeHeader);
1272    }
1273
1274    if !headers
1275        .get(header::CONNECTION)
1276        .and_then(|h| h.to_str().ok())
1277        .map(|h| h.eq_ignore_ascii_case("Upgrade"))
1278        .unwrap_or(false)
1279    {
1280        return Err(WebSocketError::InvalidConnectionHeader);
1281    }
1282
1283    let extensions = headers
1284        .get(header::SEC_WEBSOCKET_EXTENSIONS)
1285        .and_then(|h| h.to_str().ok())
1286        .map(WebSocketExtensions::from_str)
1287        .and_then(std::result::Result::ok);
1288
1289    let max_read_buffer = options.max_read_buffer.unwrap_or(
1290        options
1291            .max_payload_read
1292            .map(|payload_read| payload_read * 2)
1293            .unwrap_or(MAX_READ_BUFFER),
1294    );
1295
1296    Ok(Negotiation {
1297        extensions,
1298        compression_level,
1299        max_payload_read: options.max_payload_read.unwrap_or(MAX_PAYLOAD_READ),
1300        max_read_buffer,
1301        utf8: options.check_utf8,
1302        fragmentation: options.fragmentation.clone(),
1303        max_backpressure_write_boundary: options.max_backpressure_write_boundary,
1304    })
1305}
1306
1307fn generate_key() -> String {
1308    use base64::prelude::*;
1309    let input: [u8; 16] = rand::random();
1310    BASE64_STANDARD.encode(input)
1311}
1312
1313/// Creates a TLS connector with root certificates for secure WebSocket connections.
/// Creates a TLS connector with root certificates for secure WebSocket connections.
///
/// Trust anchors come from the bundled `webpki_roots::TLS_SERVER_ROOTS`. Client
/// authentication is disabled, and ALPN advertises only `http/1.1`.
///
/// The crypto provider is chosen as follows:
/// - the process-wide default from `CryptoProvider::get_default()`, if one is installed;
/// - otherwise, when the `rustls-ring` or `rustls-aws-lc-rs` feature is enabled, that
///   backend's default provider. If both features are enabled, the `rustls-aws-lc-rs`
///   binding is declared last and shadows the ring one, so aws-lc-rs is used.
///
/// # Panics
///
/// - If no default provider is installed and neither feature is enabled.
/// - If `with_protocol_versions(rustls::ALL_VERSIONS)` returns an error for the chosen
///   provider.
fn tls_connector() -> TlsConnector {
    // Copy each bundled root into the trust-anchor type accepted by `RootCertStore`.
    let mut root_cert_store = rustls::RootCertStore::empty();
    root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| TrustAnchor {
        subject: ta.subject.clone(),
        subject_public_key_info: ta.subject_public_key_info.clone(),
        name_constraints: ta.name_constraints.clone(),
    }));

    // A provider installed by the application takes precedence over feature defaults.
    let maybe_provider = rustls::crypto::CryptoProvider::get_default().cloned();

    #[cfg(any(feature = "rustls-ring", feature = "rustls-aws-lc-rs"))]
    let provider = maybe_provider.unwrap_or_else(|| {
        // When both features are on, the second binding shadows the first (aws-lc-rs wins).
        #[cfg(feature = "rustls-ring")]
        let _provider = rustls::crypto::ring::default_provider();
        #[cfg(feature = "rustls-aws-lc-rs")]
        let _provider = rustls::crypto::aws_lc_rs::default_provider();

        Arc::new(_provider)
    });

    // No built-in backend compiled in: a process-wide default provider is mandatory.
    #[cfg(not(any(feature = "rustls-ring", feature = "rustls-aws-lc-rs")))]
    let provider = maybe_provider.expect(
        r#"No Rustls crypto provider was enabled for yawc to connect to a `wss://` endpoint!

Either:
    - provide a `connector` in the WebSocketBuilder options
    - enable one of the following features: `rustls-ring`, `rustls-aws-lc-rs`"#,
    );

    let mut config = rustls::ClientConfig::builder_with_provider(provider)
        .with_protocol_versions(rustls::ALL_VERSIONS)
        .expect("versions")
        .with_root_certificates(root_cert_store)
        .with_no_client_auth();
    // Only HTTP/1.1 is offered: the WebSocket upgrade below is performed over HTTP/1.1.
    config.alpn_protocols = vec!["http/1.1".into()];

    TlsConnector::from(Arc::new(config))
}
1352
1353#[cfg(test)]
1354mod tests {
1355    use std::{
1356        pin::Pin,
1357        task::{Context, Poll},
1358    };
1359
1360    use futures::SinkExt;
1361    use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
1362
1363    use super::*;
1364    use crate::close::{self, CloseCode};
1365
1366    /// A mock duplex stream that wraps tokio's DuplexStream for testing.
1367    struct MockStream {
1368        inner: DuplexStream,
1369    }
1370
1371    impl MockStream {
1372        /// Creates a pair of connected mock streams.
1373        fn pair(buffer_size: usize) -> (Self, Self) {
1374            let (a, b) = tokio::io::duplex(buffer_size);
1375            (Self { inner: a }, Self { inner: b })
1376        }
1377    }
1378
1379    impl AsyncRead for MockStream {
1380        fn poll_read(
1381            mut self: Pin<&mut Self>,
1382            cx: &mut Context<'_>,
1383            buf: &mut ReadBuf<'_>,
1384        ) -> Poll<io::Result<()>> {
1385            Pin::new(&mut self.inner).poll_read(cx, buf)
1386        }
1387    }
1388
1389    impl AsyncWrite for MockStream {
1390        fn poll_write(
1391            mut self: Pin<&mut Self>,
1392            cx: &mut Context<'_>,
1393            buf: &[u8],
1394        ) -> Poll<io::Result<usize>> {
1395            Pin::new(&mut self.inner).poll_write(cx, buf)
1396        }
1397
1398        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
1399            Pin::new(&mut self.inner).poll_flush(cx)
1400        }
1401
1402        fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
1403            Pin::new(&mut self.inner).poll_shutdown(cx)
1404        }
1405    }
1406
1407    /// Helper function to create a WebSocket pair for testing.
1408    fn create_websocket_pair(buffer_size: usize) -> (WebSocket<MockStream>, WebSocket<MockStream>) {
1409        create_websocket_pair_with_config(buffer_size, None, None)
1410    }
1411
1412    fn create_websocket_pair_with_config(
1413        buffer_size: usize,
1414        fragment_size: Option<usize>,
1415        compression_level: Option<CompressionLevel>,
1416    ) -> (WebSocket<MockStream>, WebSocket<MockStream>) {
1417        let (client_stream, server_stream) = MockStream::pair(buffer_size);
1418
1419        let extensions = compression_level.map(|_level| WebSocketExtensions {
1420            server_max_window_bits: None,
1421            client_max_window_bits: None,
1422            server_no_context_takeover: false,
1423            client_no_context_takeover: false,
1424        });
1425
1426        let negotiation = Negotiation {
1427            extensions,
1428            compression_level,
1429            max_payload_read: MAX_PAYLOAD_READ,
1430            max_read_buffer: MAX_READ_BUFFER,
1431            utf8: false,
1432            fragmentation: fragment_size.map(|size| options::Fragmentation {
1433                timeout: None,
1434                fragment_size: Some(size),
1435            }),
1436            max_backpressure_write_boundary: None,
1437        };
1438
1439        let client_ws = WebSocket::new(
1440            Role::Client,
1441            client_stream,
1442            Bytes::new(),
1443            negotiation.clone(),
1444        );
1445
1446        let server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
1447
1448        (client_ws, server_ws)
1449    }
1450
1451    #[tokio::test]
1452    async fn test_send_and_receive_text_frame() {
1453        let (mut client, mut server) = create_websocket_pair(1024);
1454
1455        let text = "Hello, WebSocket!";
1456        client
1457            .send(Frame::text(text))
1458            .await
1459            .expect("Failed to send text frame");
1460
1461        let frame = server.next_frame().await.expect("Failed to receive frame");
1462
1463        assert_eq!(frame.opcode(), OpCode::Text);
1464        assert_eq!(frame.payload(), text.as_bytes());
1465        assert!(frame.is_fin());
1466    }
1467
1468    #[tokio::test]
1469    async fn test_send_and_receive_binary_frame() {
1470        let (mut client, mut server) = create_websocket_pair(1024);
1471
1472        let data = vec![1u8, 2, 3, 4, 5];
1473        client
1474            .send(Frame::binary(data.clone()))
1475            .await
1476            .expect("Failed to send binary frame");
1477
1478        let frame = server.next_frame().await.expect("Failed to receive frame");
1479
1480        assert_eq!(frame.opcode(), OpCode::Binary);
1481        assert_eq!(frame.payload(), &data[..]);
1482        assert!(frame.is_fin());
1483    }
1484
1485    #[tokio::test]
1486    async fn test_bidirectional_communication() {
1487        let (mut client, mut server) = create_websocket_pair(2048);
1488
1489        client
1490            .send(Frame::text("Client message"))
1491            .await
1492            .expect("Failed to send from client");
1493
1494        let frame = server
1495            .next_frame()
1496            .await
1497            .expect("Failed to receive at server");
1498        assert_eq!(frame.payload(), b"Client message" as &[u8]);
1499
1500        server
1501            .send(Frame::text("Server response"))
1502            .await
1503            .expect("Failed to send from server");
1504
1505        let frame = client
1506            .next_frame()
1507            .await
1508            .expect("Failed to receive at client");
1509        assert_eq!(frame.payload(), b"Server response" as &[u8]);
1510    }
1511
1512    #[tokio::test]
1513    async fn test_ping_pong() {
1514        let (mut client, mut server) = create_websocket_pair(1024);
1515
1516        // Ping frames are handled automatically by the WebSocket implementation
1517        // The server will automatically respond with a pong, but we won't receive it via next_frame()
1518        // Instead, test that we can send and receive pong frames explicitly
1519
1520        client
1521            .send(Frame::pong("pong_data"))
1522            .await
1523            .expect("Failed to send pong");
1524
1525        let frame = server.next_frame().await.expect("Failed to receive pong");
1526        assert_eq!(frame.opcode(), OpCode::Pong);
1527        assert_eq!(frame.payload(), b"pong_data" as &[u8]);
1528    }
1529
1530    #[tokio::test]
1531    async fn test_close_frame() {
1532        let (mut client, mut server) = create_websocket_pair(1024);
1533
1534        client
1535            .send(Frame::close(CloseCode::Normal, b"Goodbye"))
1536            .await
1537            .expect("Failed to send close frame");
1538
1539        let frame = server
1540            .next_frame()
1541            .await
1542            .expect("Failed to receive close frame");
1543
1544        assert_eq!(frame.opcode(), OpCode::Close);
1545        assert_eq!(frame.close_code(), Some(close::CloseCode::Normal));
1546        assert_eq!(
1547            frame.close_reason().expect("Invalid close reason"),
1548            Some("Goodbye")
1549        );
1550    }
1551
1552    #[tokio::test]
1553    async fn test_large_message() {
1554        let (mut client, mut server) = create_websocket_pair(65536);
1555
1556        let large_data = vec![42u8; 10240];
1557        client
1558            .send(Frame::binary(large_data.clone()))
1559            .await
1560            .expect("Failed to send large message");
1561
1562        let frame = server
1563            .next_frame()
1564            .await
1565            .expect("Failed to receive large message");
1566
1567        assert_eq!(frame.opcode(), OpCode::Binary);
1568        assert_eq!(frame.payload().len(), 10240);
1569        assert_eq!(frame.payload(), &large_data[..]);
1570    }
1571
1572    #[tokio::test]
1573    async fn test_multiple_messages() {
1574        let (mut client, mut server) = create_websocket_pair(4096);
1575
1576        for i in 0..10 {
1577            let msg = format!("Message {}", i);
1578            client
1579                .send(Frame::text(msg.clone()))
1580                .await
1581                .expect("Failed to send message");
1582
1583            let frame = server
1584                .next_frame()
1585                .await
1586                .expect("Failed to receive message");
1587            assert_eq!(frame.payload(), msg.as_bytes());
1588        }
1589    }
1590
1591    #[tokio::test]
1592    async fn test_empty_payload() {
1593        let (mut client, mut server) = create_websocket_pair(1024);
1594
1595        client
1596            .send(Frame::text(Bytes::new()))
1597            .await
1598            .expect("Failed to send empty frame");
1599
1600        let frame = server
1601            .next_frame()
1602            .await
1603            .expect("Failed to receive empty frame");
1604
1605        assert_eq!(frame.opcode(), OpCode::Text);
1606        assert_eq!(frame.payload().len(), 0);
1607    }
1608
1609    #[tokio::test]
1610    async fn test_fragmented_message() {
1611        let (mut client, mut server) = create_websocket_pair(2048);
1612
1613        let mut frame1 = Frame::text("Hello, ");
1614        frame1.set_fin(false);
1615        client
1616            .send(frame1)
1617            .await
1618            .expect("Failed to send first fragment");
1619
1620        let frame2 = Frame::continuation("World!");
1621        client
1622            .send(frame2)
1623            .await
1624            .expect("Failed to send final fragment");
1625
1626        // WebSocket automatically reassembles fragments
1627        // We receive one complete message with the concatenated payload
1628        let received = server
1629            .next_frame()
1630            .await
1631            .expect("Failed to receive message");
1632        assert_eq!(received.opcode(), OpCode::Text);
1633        assert!(received.is_fin());
1634        assert_eq!(received.payload(), b"Hello, World!" as &[u8]);
1635    }
1636
1637    #[tokio::test]
1638    async fn test_concurrent_send_receive() {
1639        let (mut client, mut server) = create_websocket_pair(4096);
1640
1641        let client_task = tokio::spawn(async move {
1642            for i in 0..5 {
1643                client
1644                    .send(Frame::text(format!("Client {}", i)))
1645                    .await
1646                    .expect("Failed to send from client");
1647
1648                let frame = client
1649                    .next_frame()
1650                    .await
1651                    .expect("Failed to receive at client");
1652                assert_eq!(frame.payload(), format!("Server {}", i).as_bytes());
1653            }
1654            client
1655        });
1656
1657        let server_task = tokio::spawn(async move {
1658            for i in 0..5 {
1659                let frame = server
1660                    .next_frame()
1661                    .await
1662                    .expect("Failed to receive at server");
1663                assert_eq!(frame.payload(), format!("Client {}", i).as_bytes());
1664
1665                server
1666                    .send(Frame::text(format!("Server {}", i)))
1667                    .await
1668                    .expect("Failed to send from server");
1669            }
1670            server
1671        });
1672
1673        client_task.await.expect("Client task failed");
1674        server_task.await.expect("Server task failed");
1675    }
1676
1677    #[tokio::test]
1678    async fn test_utf8_validation() {
1679        let (mut client, mut server) = create_websocket_pair(1024);
1680
1681        let valid_utf8 = "Hello, 世界! 🌍";
1682        client
1683            .send(Frame::text(valid_utf8))
1684            .await
1685            .expect("Failed to send UTF-8 text");
1686
1687        let frame = server
1688            .next_frame()
1689            .await
1690            .expect("Failed to receive UTF-8 text");
1691        assert_eq!(frame.opcode(), OpCode::Text);
1692        assert!(frame.is_utf8());
1693        assert_eq!(std::str::from_utf8(frame.payload()).unwrap(), valid_utf8);
1694    }
1695
1696    #[tokio::test]
1697    async fn test_stream_trait_implementation() {
1698        use futures::StreamExt;
1699
1700        let (mut client, mut server) = create_websocket_pair(1024);
1701
1702        tokio::spawn(async move {
1703            for i in 0..3 {
1704                client
1705                    .send(Frame::text(format!("Message {}", i)))
1706                    .await
1707                    .expect("Failed to send message");
1708            }
1709        });
1710
1711        let mut count = 0;
1712        while let Some(frame) = server.next().await {
1713            assert_eq!(frame.opcode(), OpCode::Text);
1714            count += 1;
1715            if count == 3 {
1716                break;
1717            }
1718        }
1719        assert_eq!(count, 3);
1720    }
1721
1722    #[tokio::test]
1723    async fn test_sink_trait_implementation() {
1724        use futures::SinkExt;
1725
1726        let (mut client, mut server) = create_websocket_pair(1024);
1727
1728        client
1729            .send(Frame::text("Sink message"))
1730            .await
1731            .expect("Failed to send via Sink");
1732
1733        client.flush().await.expect("Failed to flush");
1734
1735        let frame = server
1736            .next_frame()
1737            .await
1738            .expect("Failed to receive message");
1739        assert_eq!(frame.payload(), b"Sink message" as &[u8]);
1740    }
1741
1742    #[tokio::test]
1743    async fn test_rapid_small_messages() {
1744        let (mut client, mut server) = create_websocket_pair(8192);
1745
1746        let count = 100;
1747
1748        let sender = tokio::spawn(async move {
1749            for i in 0..count {
1750                client
1751                    .send(Frame::text(format!("{}", i)))
1752                    .await
1753                    .expect("Failed to send");
1754            }
1755            client
1756        });
1757
1758        for i in 0..count {
1759            let frame = server.next_frame().await.expect("Failed to receive");
1760            assert_eq!(frame.payload(), format!("{}", i).as_bytes());
1761        }
1762
1763        sender.await.expect("Sender task failed");
1764    }
1765
1766    #[tokio::test]
1767    async fn test_interleaved_control_and_data_frames() {
1768        let (mut client, mut server) = create_websocket_pair(2048);
1769
1770        client
1771            .send(Frame::text("Data 1"))
1772            .await
1773            .expect("Failed to send");
1774
1775        // Ping frames are handled automatically and don't appear in next_frame()
1776        // Use pong frames instead to test control frame interleaving
1777        client
1778            .send(Frame::pong("pong"))
1779            .await
1780            .expect("Failed to send pong");
1781
1782        client
1783            .send(Frame::binary(vec![1, 2, 3]))
1784            .await
1785            .expect("Failed to send");
1786
1787        let f1 = server.next_frame().await.expect("Failed to receive");
1788        assert_eq!(f1.opcode(), OpCode::Text);
1789        assert_eq!(f1.payload(), b"Data 1" as &[u8]);
1790
1791        let f2 = server.next_frame().await.expect("Failed to receive");
1792        assert_eq!(f2.opcode(), OpCode::Pong);
1793
1794        let f3 = server.next_frame().await.expect("Failed to receive");
1795        assert_eq!(f3.opcode(), OpCode::Binary);
1796        assert_eq!(f3.payload(), &[1u8, 2, 3] as &[u8]);
1797    }
1798
1799    #[tokio::test]
1800    async fn test_client_sends_masked_frames() {
1801        let (mut client, mut _server) = create_websocket_pair(1024);
1802
1803        // Create a frame and send it through the client
1804        let frame = Frame::text("test");
1805        client.send(frame).await.expect("Failed to send");
1806
1807        // The frame should be automatically masked by the client encoder
1808        // We can't directly verify this without inspecting the wire format,
1809        // but the test verifies the codec path works correctly
1810    }
1811
1812    #[tokio::test]
1813    async fn test_server_sends_unmasked_frames() {
1814        let (mut _client, mut server) = create_websocket_pair(1024);
1815
1816        // Server frames should not be masked
1817        let frame = Frame::text("test");
1818        server.send(frame).await.expect("Failed to send");
1819
1820        // Similar to above - verifies the codec path
1821    }
1822
1823    #[tokio::test]
1824    async fn test_close_code_variants() {
1825        let (mut client, mut server) = create_websocket_pair(1024);
1826
1827        client
1828            .send(Frame::close(close::CloseCode::Away, b""))
1829            .await
1830            .expect("Failed to send close");
1831
1832        let frame = server.next_frame().await.expect("Failed to receive");
1833        assert_eq!(frame.close_code(), Some(close::CloseCode::Away));
1834    }
1835
1836    #[tokio::test]
1837    async fn test_multiple_fragments() {
1838        let (mut client, mut server) = create_websocket_pair(4096);
1839
1840        // Send 5 fragments
1841        for i in 0..5 {
1842            let is_last = i == 4;
1843            let opcode = if i == 0 {
1844                OpCode::Text
1845            } else {
1846                OpCode::Continuation
1847            };
1848
1849            let mut frame = Frame::from((opcode, format!("part{}", i)));
1850            frame.set_fin(is_last);
1851            client.send(frame).await.expect("Failed to send fragment");
1852        }
1853
1854        // WebSocket automatically reassembles fragments
1855        // We receive one complete message, not individual fragments
1856        let frame = server.next_frame().await.expect("Failed to receive");
1857        assert_eq!(frame.opcode(), OpCode::Text);
1858        assert!(frame.is_fin());
1859
1860        // The payload should be the concatenation of all fragments
1861        let expected = "part0part1part2part3part4";
1862        assert_eq!(frame.payload(), expected.as_bytes());
1863    }
1864
1865    #[tokio::test]
1866    async fn test_automatic_fragmentation_large_messages() {
1867        // Create WebSocket pair with fragment_size set to 100 bytes
1868        let (client_stream, server_stream) = MockStream::pair(8192);
1869
1870        let negotiation = Negotiation {
1871            extensions: None,
1872            compression_level: None,
1873            max_payload_read: MAX_PAYLOAD_READ,
1874            max_read_buffer: MAX_READ_BUFFER,
1875            utf8: false,
1876            fragmentation: Some(options::Fragmentation {
1877                timeout: None,
1878                fragment_size: Some(100),
1879            }),
1880            max_backpressure_write_boundary: None,
1881        };
1882
1883        let mut client_ws = WebSocket::new(
1884            Role::Client,
1885            client_stream,
1886            Bytes::new(),
1887            negotiation.clone(),
1888        );
1889
1890        let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
1891
1892        // Send a large message (300 bytes) from client
1893        let large_payload = vec![b'A'; 300];
1894        client_ws
1895            .send(Frame::binary(large_payload.clone()))
1896            .await
1897            .unwrap();
1898
1899        // Server should receive the complete message (reassembled from fragments)
1900        let received = server_ws.next_frame().await.unwrap();
1901        assert_eq!(received.opcode(), OpCode::Binary);
1902        assert_eq!(received.payload(), large_payload.as_slice());
1903    }
1904
1905    #[tokio::test]
1906    async fn test_automatic_fragmentation_small_messages() {
1907        // Create WebSocket pair with fragment_size set to 100 bytes
1908        let (client_stream, server_stream) = MockStream::pair(8192);
1909
1910        let negotiation = Negotiation {
1911            extensions: None,
1912            compression_level: None,
1913            max_payload_read: MAX_PAYLOAD_READ,
1914            max_read_buffer: MAX_READ_BUFFER,
1915            utf8: false,
1916            fragmentation: Some(options::Fragmentation {
1917                timeout: None,
1918                fragment_size: Some(100),
1919            }),
1920            max_backpressure_write_boundary: None,
1921        };
1922
1923        let mut client_ws = WebSocket::new(
1924            Role::Client,
1925            client_stream,
1926            Bytes::new(),
1927            negotiation.clone(),
1928        );
1929
1930        let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
1931
1932        // Send a small message (50 bytes) from client
1933        let small_payload = vec![b'B'; 50];
1934        client_ws
1935            .send(Frame::text(small_payload.clone()))
1936            .await
1937            .unwrap();
1938
1939        // Server should receive the complete message
1940        let received = server_ws.next_frame().await.unwrap();
1941        assert_eq!(received.opcode(), OpCode::Text);
1942        assert_eq!(received.payload(), small_payload.as_slice());
1943    }
1944
1945    #[tokio::test]
1946    async fn test_no_fragmentation_when_not_configured() {
1947        // Create WebSocket pair WITHOUT fragment_size
1948        let (client_stream, server_stream) = MockStream::pair(8192);
1949
1950        let negotiation = Negotiation {
1951            extensions: None,
1952            compression_level: None,
1953            max_payload_read: MAX_PAYLOAD_READ,
1954            max_read_buffer: MAX_READ_BUFFER,
1955            utf8: false,
1956            fragmentation: None,
1957            max_backpressure_write_boundary: None,
1958        };
1959
1960        let mut client_ws = WebSocket::new(
1961            Role::Client,
1962            client_stream,
1963            Bytes::new(),
1964            negotiation.clone(),
1965        );
1966
1967        let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
1968
1969        // Send a large message (1000 bytes) without fragmentation config
1970        let large_payload = vec![b'C'; 1000];
1971        client_ws
1972            .send(Frame::binary(large_payload.clone()))
1973            .await
1974            .unwrap();
1975
1976        // Server should receive the complete message
1977        let received = server_ws.next_frame().await.unwrap();
1978        assert_eq!(received.opcode(), OpCode::Binary);
1979        assert_eq!(received.payload(), large_payload.as_slice());
1980    }
1981
1982    #[tokio::test]
1983    async fn test_interleave_control_frames_with_continuation_frames() {
1984        // Per RFC 6455 Section 5.5:
1985        // "Control frames themselves MUST NOT be fragmented."
1986        // "Control frames MAY be injected in the middle of a fragmented message."
1987        //
1988        // This test verifies that control frames (ping/pong) can be interleaved
1989        // with continuation frames during message fragmentation, and that the
1990        // fragmented message is still correctly reassembled.
1991        let (mut client, mut server) = create_websocket_pair(4096);
1992
1993        // Send first fragment of a text message (FIN=0)
1994        let mut fragment1 = Frame::text("Hello, ");
1995        fragment1.set_fin(false);
1996        client
1997            .send(fragment1)
1998            .await
1999            .expect("Failed to send first fragment");
2000
2001        // Interleave a ping frame in the middle of the fragmented message
2002        client
2003            .send(Frame::ping("ping during fragmentation"))
2004            .await
2005            .expect("Failed to send ping");
2006
2007        // Send second continuation fragment (FIN=0)
2008        let mut fragment2 = Frame::continuation("World");
2009        fragment2.set_fin(false);
2010        client
2011            .send(fragment2)
2012            .await
2013            .expect("Failed to send second fragment");
2014
2015        // Interleave a pong frame
2016        client
2017            .send(Frame::pong("pong during fragmentation"))
2018            .await
2019            .expect("Failed to send pong");
2020
2021        // Send final continuation fragment (FIN=1)
2022        let fragment3 = Frame::continuation("!");
2023        client
2024            .send(fragment3)
2025            .await
2026            .expect("Failed to send final fragment");
2027
2028        // Server should receive the ping frame first
2029        let ping_frame = server
2030            .next_frame()
2031            .await
2032            .expect("Failed to receive ping frame");
2033        assert_eq!(ping_frame.opcode(), OpCode::Ping);
2034        assert_eq!(ping_frame.payload(), b"ping during fragmentation" as &[u8]);
2035
2036        // Server should receive the pong frame
2037        let pong_frame = server
2038            .next_frame()
2039            .await
2040            .expect("Failed to receive pong frame");
2041        assert_eq!(pong_frame.opcode(), OpCode::Pong);
2042        assert_eq!(pong_frame.payload(), b"pong during fragmentation" as &[u8]);
2043
2044        // Server should receive the complete reassembled message
2045        let message_frame = server
2046            .next_frame()
2047            .await
2048            .expect("Failed to receive reassembled message");
2049        assert_eq!(message_frame.opcode(), OpCode::Text);
2050        assert!(message_frame.is_fin());
2051        assert_eq!(message_frame.payload(), b"Hello, World!" as &[u8]);
2052    }
2053
2054    #[tokio::test]
2055    async fn test_large_compressed_fragmented_payload() {
2056        // Test large payload with manual compression and fragmentation
2057        // Fragment size: 65536 bytes
2058        // This tests that:
2059        // 1. User can manually fragment large payloads using set_fin()
2060        // 2. Compression works across manually created fragments
2061        // 3. Decompression and reassembly produce the original payload
2062
2063        const FRAGMENT_SIZE: usize = 65536;
2064        const PAYLOAD_SIZE: usize = 1024 * 1024; // 1 MB
2065
2066        use flate2::Compression;
2067
2068        let (mut client, mut server) = create_websocket_pair_with_config(
2069            256 * 1024, // 256 KB buffer
2070            None,       // No automatic fragmentation - we do it manually
2071            Some(Compression::best()),
2072        );
2073
2074        // Create a large payload with repetitive data (compresses well)
2075        let payload: Vec<u8> = (0..PAYLOAD_SIZE).map(|i| (i % 256) as u8).collect();
2076
2077        // Manually fragment the payload and send as separate frames
2078        let total_fragments = PAYLOAD_SIZE.div_ceil(FRAGMENT_SIZE);
2079        println!(
2080            "Sending {} bytes in {} fragments of {} bytes each",
2081            PAYLOAD_SIZE, total_fragments, FRAGMENT_SIZE
2082        );
2083
2084        // Spawn server task to receive the message concurrently
2085        let server_task = tokio::spawn(async move {
2086            server
2087                .next_frame()
2088                .await
2089                .expect("Failed to receive large payload")
2090        });
2091
2092        // Send all fragments from client
2093        let mut offset = 0;
2094        let mut fragment_num = 0;
2095
2096        while offset < PAYLOAD_SIZE {
2097            let end = std::cmp::min(offset + FRAGMENT_SIZE, PAYLOAD_SIZE);
2098            let chunk = payload[offset..end].to_vec();
2099            let is_final = end == PAYLOAD_SIZE;
2100
2101            let mut frame = if fragment_num == 0 {
2102                // First frame: use Binary opcode
2103                Frame::binary(chunk)
2104            } else {
2105                // Continuation frames
2106                Frame::continuation(chunk)
2107            };
2108
2109            // Set FIN bit only on the last fragment
2110            frame.set_fin(is_final);
2111
2112            println!(
2113                "Sending fragment {}/{}: {} bytes, OpCode={:?} FIN={}",
2114                fragment_num + 1,
2115                total_fragments,
2116                frame.payload().len(),
2117                frame.opcode(),
2118                is_final
2119            );
2120
2121            client
2122                .send(frame)
2123                .await
2124                .unwrap_or_else(|_| panic!("Failed to send fragment {}", fragment_num + 1));
2125
2126            offset = end;
2127            fragment_num += 1;
2128        }
2129
2130        // Wait for server to receive the complete message
2131        let received_frame = server_task.await.expect("Server task failed");
2132
2133        // Verify the payload was reassembled correctly
2134        assert_eq!(received_frame.opcode(), OpCode::Binary);
2135        assert!(received_frame.is_fin());
2136        assert_eq!(received_frame.payload().len(), PAYLOAD_SIZE);
2137        assert_eq!(received_frame.payload().as_ref(), &payload[..]);
2138
2139        println!(
2140            "Successfully sent {} manual fragments, compressed, decompressed, and reassembled {} bytes",
2141            total_fragments, PAYLOAD_SIZE
2142        );
2143    }
2144
2145    #[tokio::test]
2146    async fn test_compressed_fragmented_with_interleaved_control() {
2147        // Test compression + fragmentation + interleaved control frames
2148        // This is the most complex scenario combining:
2149        // 1. Compression (best quality)
2150        // 2. Automatic fragmentation
2151        // 3. Control frames between fragments
2152
2153        const FRAGMENT_SIZE: usize = 65536;
2154
2155        use flate2::Compression;
2156
2157        let (mut client, mut server) = create_websocket_pair_with_config(
2158            128 * 1024,
2159            Some(FRAGMENT_SIZE),
2160            Some(Compression::best()),
2161        );
2162
2163        // Create a payload that will span multiple fragments
2164        let payload = "This is a test payload that should compress well. ".repeat(5000);
2165        let original_payload = payload.clone();
2166        let payload_bytes = payload.as_bytes().to_vec();
2167
2168        // Send the payload (will be fragmented automatically)
2169        tokio::spawn(async move {
2170            client
2171                .send(Frame::binary(payload_bytes))
2172                .await
2173                .expect("Failed to send payload");
2174
2175            // Send a ping after starting the fragmented message
2176            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
2177            client
2178                .send(Frame::ping("test"))
2179                .await
2180                .expect("Failed to send ping");
2181        });
2182
2183        // Server should be able to receive the fragmented message and control frames
2184        let mut received_message = None;
2185        let mut received_ping = false;
2186
2187        for _ in 0..2 {
2188            let frame = server.next_frame().await.expect("Failed to receive frame");
2189
2190            match frame.opcode() {
2191                OpCode::Binary => {
2192                    assert!(frame.is_fin());
2193                    received_message = Some(frame.payload().to_vec());
2194                }
2195                OpCode::Ping => {
2196                    received_ping = true;
2197                }
2198                _ => panic!("Unexpected frame type: {:?}", frame.opcode()),
2199            }
2200        }
2201
2202        assert!(received_message.is_some(), "Message not received");
2203        assert!(received_ping, "Ping not received");
2204
2205        let received = String::from_utf8(received_message.unwrap())
2206            .expect("Invalid UTF-8 in received payload");
2207
2208        assert_eq!(
2209            received, original_payload,
2210            "Compressed fragmented payload mismatch"
2211        );
2212    }
2213}