// sockudo_ws/stream/websocket.rs

1//! WebSocket stream implementation
2//!
3//! This module provides the main `WebSocketStream` type.
4
5use std::io;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use bytes::BytesMut;
10use futures_core::Stream;
11use futures_sink::Sink;
12use pin_project_lite::pin_project;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14
15use crate::Config;
16use crate::cork::CorkBuffer;
17use crate::error::{CloseReason, Error, Result};
18use crate::protocol::{Message, Protocol, Role};
19
20/// Default high water mark for backpressure (64KB)
21const DEFAULT_HIGH_WATER_MARK: usize = 64 * 1024;
22
23/// Default low water mark for backpressure (16KB)
24const DEFAULT_LOW_WATER_MARK: usize = 16 * 1024;
25
pin_project! {
    /// A WebSocket stream over an async transport
    ///
    /// This type implements both `Stream<Item = Result<Message>>` for receiving
    /// and `Sink<Message>` for sending messages.
    ///
    /// # Backpressure
    ///
    /// The stream supports backpressure monitoring through `is_backpressured()` and
    /// `write_buffer_len()` methods. When the write buffer exceeds the high water mark,
    /// producers should pause sending until the buffer drains below the low water mark.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use futures_util::{SinkExt, StreamExt};
    /// use sockudo_ws::WebSocketStream;
    ///
    /// async fn handle(mut ws: WebSocketStream<TcpStream>) {
    ///     while let Some(msg) = ws.next().await {
    ///         match msg {
    ///             Ok(Message::Text(text)) => {
    ///                 // Check backpressure before sending
    ///                 if ws.is_backpressured() {
    ///                     ws.flush().await?;
    ///                 }
    ///                 ws.send(Message::Text(text)).await?;
    ///             }
    ///             Ok(Message::Close(_)) => break,
    ///             _ => {}
    ///         }
    ///     }
    /// }
    /// ```
    pub struct WebSocketStream<S> {
        // Underlying async transport (e.g. TCP/TLS stream); pinned because
        // `poll_read`/`poll_write` require `Pin<&mut S>`.
        #[pin]
        inner: S,
        // Frame encoder/decoder state machine
        protocol: Protocol,
        // Raw bytes received from the transport, not yet parsed into frames
        read_buf: BytesMut,
        // Encoded frames awaiting transmission; drained with vectored writes
        write_buf: CorkBuffer,
        // Connection lifecycle state (see `StreamState`)
        state: StreamState,
        // Configuration this stream was built with
        config: Config,
        // Pending messages from last process() call, handed out one at a
        // time by `next_pending_message` using `pending_index` as a cursor
        pending_messages: Vec<Message>,
        pending_index: usize,
        // Backpressure thresholds (bytes) consulted by `is_backpressured`
        // and `is_write_buffer_low`
        high_water_mark: usize,
        low_water_mark: usize,
    }
}
76
/// Connection lifecycle state for [`WebSocketStream`].
///
/// Observed transitions in this file: `Open` → `CloseSent` (local close
/// queued) and `Open`/`CloseSent` → `Closed` (EOF, peer close, or shutdown).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Normal operation
    Open,
    /// Flushing write buffer
    // NOTE(review): `Flushing` is never constructed in this file — confirm it
    // is used elsewhere or consider removing the variant.
    Flushing,
    /// Close frame sent
    CloseSent,
    /// Connection closed
    Closed,
}
88
89impl<S> WebSocketStream<S>
90where
91    S: AsyncRead + AsyncWrite + Unpin,
92{
93    /// Create a new WebSocket stream from an already-upgraded connection
94    pub fn from_raw(inner: S, role: Role, config: Config) -> Self {
95        let protocol = Protocol::new(role, config.max_frame_size, config.max_message_size);
96
97        Self {
98            inner,
99            protocol,
100            read_buf: BytesMut::with_capacity(crate::RECV_BUFFER_SIZE),
101            write_buf: CorkBuffer::with_capacity(config.write_buffer_size),
102            state: StreamState::Open,
103            config,
104            pending_messages: Vec::new(),
105            pending_index: 0,
106            high_water_mark: DEFAULT_HIGH_WATER_MARK,
107            low_water_mark: DEFAULT_LOW_WATER_MARK,
108        }
109    }
110
111    /// Create a server-side WebSocket stream
112    pub fn server(inner: S, config: Config) -> Self {
113        Self::from_raw(inner, Role::Server, config)
114    }
115
116    /// Create a client-side WebSocket stream
117    pub fn client(inner: S, config: Config) -> Self {
118        Self::from_raw(inner, Role::Client, config)
119    }
120
121    /// Get a reference to the underlying stream
122    pub fn get_ref(&self) -> &S {
123        &self.inner
124    }
125
126    /// Get a mutable reference to the underlying stream
127    pub fn get_mut(&mut self) -> &mut S {
128        &mut self.inner
129    }
130
131    /// Consume the WebSocket stream and return the underlying stream
132    pub fn into_inner(self) -> S {
133        self.inner
134    }
135
136    /// Check if the connection is closed
137    pub fn is_closed(&self) -> bool {
138        self.state == StreamState::Closed
139    }
140
141    // ========================================================================
142    // Backpressure API
143    // ========================================================================
144
145    /// Check if the write buffer is backpressured
146    ///
147    /// Returns `true` when the write buffer has exceeded the high water mark.
148    /// Producers should pause sending new messages until `is_write_buffer_low()`
149    /// returns `true` or until the buffer is flushed.
150    ///
151    /// # Example
152    ///
153    /// ```ignore
154    /// if ws.is_backpressured() {
155    ///     // Wait for buffer to drain before sending more
156    ///     ws.flush().await?;
157    /// }
158    /// ```
159    #[inline]
160    pub fn is_backpressured(&self) -> bool {
161        self.write_buf.pending_bytes() > self.high_water_mark
162    }
163
164    /// Check if the write buffer is below the low water mark
165    ///
166    /// Returns `true` when the write buffer has drained below the low water mark.
167    /// This can be used to resume sending after backpressure was detected.
168    #[inline]
169    pub fn is_write_buffer_low(&self) -> bool {
170        self.write_buf.pending_bytes() <= self.low_water_mark
171    }
172
173    /// Get the current write buffer size in bytes
174    ///
175    /// Useful for monitoring and debugging backpressure issues.
176    #[inline]
177    pub fn write_buffer_len(&self) -> usize {
178        self.write_buf.pending_bytes()
179    }
180
181    /// Get the current read buffer size in bytes
182    ///
183    /// Useful for monitoring memory usage and debugging.
184    #[inline]
185    pub fn read_buffer_len(&self) -> usize {
186        self.read_buf.len()
187    }
188
189    /// Set the high water mark for backpressure
190    ///
191    /// When the write buffer exceeds this threshold, `is_backpressured()` returns `true`.
192    /// Default is 64KB.
193    #[inline]
194    pub fn set_high_water_mark(&mut self, size: usize) {
195        self.high_water_mark = size;
196    }
197
198    /// Set the low water mark for backpressure
199    ///
200    /// When the write buffer drops below this threshold, `is_write_buffer_low()` returns `true`.
201    /// Default is 16KB.
202    #[inline]
203    pub fn set_low_water_mark(&mut self, size: usize) {
204        self.low_water_mark = size;
205    }
206
207    /// Get the current high water mark
208    #[inline]
209    pub fn high_water_mark(&self) -> usize {
210        self.high_water_mark
211    }
212
213    /// Get the current low water mark
214    #[inline]
215    pub fn low_water_mark(&self) -> usize {
216        self.low_water_mark
217    }
218
    /// Send a close frame
    ///
    /// Encodes a Close frame with `code`/`reason` and flushes it to the
    /// transport. This is a no-op (returns `Ok`) unless the stream is
    /// currently `Open`; on success the state becomes `CloseSent` — the
    /// transport itself is not shut down here.
    ///
    /// NOTE(review): after `close()` the state is `CloseSent`, but
    /// `start_send` only rejects messages when the state is `Closed`, so
    /// frames can still be queued after a close — confirm intended.
    pub async fn close(&mut self, code: u16, reason: &str) -> Result<()> {
        if self.state != StreamState::Open {
            return Ok(());
        }

        let close = Message::Close(Some(CloseReason::new(code, reason)));
        self.protocol
            .encode_message(&close, self.write_buf.buffer_mut())?;
        self.state = StreamState::CloseSent;

        // Flush the close frame
        self.flush_write_buf().await?;
        Ok(())
    }
234
235    /// Flush the write buffer to the underlying stream
236    async fn flush_write_buf(&mut self) -> Result<()> {
237        use tokio::io::AsyncWriteExt;
238
239        while self.write_buf.has_data() {
240            let slices = self.write_buf.get_write_slices();
241            if slices.is_empty() {
242                break;
243            }
244
245            let n = self.inner.write_vectored(&slices).await?;
246            if n == 0 {
247                return Err(Error::ConnectionClosed);
248            }
249            self.write_buf.consume(n);
250        }
251
252        self.inner.flush().await?;
253        Ok(())
254    }
255
256    /// Read more data from the underlying stream
257    fn poll_read_more(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
258        let this = self.project();
259
260        // Ensure we have space in the buffer
261        if this.read_buf.capacity() - this.read_buf.len() < 4096 {
262            this.read_buf.reserve(8192);
263        }
264
265        // Get a slice of uninitialized memory
266        let buf_len = this.read_buf.len();
267        let buf_cap = this.read_buf.capacity();
268
269        // SAFETY: We're extending into the spare capacity
270        unsafe {
271            this.read_buf.set_len(buf_cap);
272        }
273
274        let mut read_buf = ReadBuf::new(&mut this.read_buf[buf_len..]);
275
276        match this.inner.poll_read(cx, &mut read_buf) {
277            Poll::Ready(Ok(())) => {
278                let n = read_buf.filled().len();
279                unsafe {
280                    this.read_buf.set_len(buf_len + n);
281                }
282                if n == 0 {
283                    Poll::Ready(Ok(0))
284                } else {
285                    Poll::Ready(Ok(n))
286                }
287            }
288            Poll::Ready(Err(e)) => {
289                unsafe {
290                    this.read_buf.set_len(buf_len);
291                }
292                Poll::Ready(Err(e))
293            }
294            Poll::Pending => {
295                unsafe {
296                    this.read_buf.set_len(buf_len);
297                }
298                Poll::Pending
299            }
300        }
301    }
302
303    /// Process read buffer and extract messages
304    fn process_read_buf(&mut self) -> Result<()> {
305        if self.read_buf.is_empty() {
306            return Ok(());
307        }
308
309        let messages = self.protocol.process(&mut self.read_buf)?;
310
311        if !messages.is_empty() {
312            self.pending_messages = messages;
313            self.pending_index = 0;
314        }
315
316        Ok(())
317    }
318
319    /// Get the next pending message
320    fn next_pending_message(&mut self) -> Option<Message> {
321        if self.pending_index < self.pending_messages.len() {
322            let msg = self.pending_messages[self.pending_index].clone();
323            self.pending_index += 1;
324
325            // Clear when all consumed
326            if self.pending_index >= self.pending_messages.len() {
327                self.pending_messages.clear();
328                self.pending_index = 0;
329            }
330
331            Some(msg)
332        } else {
333            None
334        }
335    }
336}
337
impl<S> Stream for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Result<Message>;

    /// Poll for the next complete message.
    ///
    /// Control-frame behavior: an incoming Ping both queues an automatic
    /// Pong into the write buffer AND is yielded to the caller; an incoming
    /// Close queues a close response, marks the stream `Closed`, and is
    /// yielded.
    ///
    /// NOTE(review): queued Pong/Close responses are only *encoded* into
    /// `write_buf` here — they reach the wire on the next `poll_flush`.
    /// In particular, once the state is `Closed` only an explicit flush will
    /// get the close response out; confirm callers do this.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Check for connection closed
            if self.state == StreamState::Closed {
                return Poll::Ready(None);
            }

            // First, return any pending messages
            if let Some(msg) = self.as_mut().get_mut().next_pending_message() {
                // Handle control frames
                match &msg {
                    Message::Ping(data) => {
                        // Queue pong response (written on next flush)
                        let this = self.as_mut().get_mut();
                        this.protocol.encode_pong(data, this.write_buf.buffer_mut());
                    }
                    Message::Close(reason) => {
                        let this = self.as_mut().get_mut();
                        if this.state == StreamState::Open {
                            // Queue close response and stop yielding further
                            // messages on subsequent polls.
                            this.protocol
                                .encode_close_response(this.write_buf.buffer_mut());
                            this.state = StreamState::Closed;
                        }
                        return Poll::Ready(Some(Ok(Message::Close(reason.clone()))));
                    }
                    _ => {}
                }

                return Poll::Ready(Some(Ok(msg)));
            }

            // Try to read more data
            match self.as_mut().poll_read_more(cx) {
                Poll::Ready(Ok(0)) => {
                    // EOF - connection closed
                    self.as_mut().get_mut().state = StreamState::Closed;
                    return Poll::Ready(None);
                }
                Poll::Ready(Ok(_n)) => {
                    // Process the new data
                    match self.as_mut().get_mut().process_read_buf() {
                        Ok(()) => continue, // Loop to check for messages
                        Err(e) => return Poll::Ready(Some(Err(e))),
                    }
                }
                Poll::Ready(Err(e)) => {
                    return Poll::Ready(Some(Err(e.into())));
                }
                Poll::Pending => {
                    // No more data available right now
                    return Poll::Pending;
                }
            }
        }
    }
}
401
402impl<S> Sink<Message> for WebSocketStream<S>
403where
404    S: AsyncRead + AsyncWrite + Unpin,
405{
406    type Error = Error;
407
408    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
409        if self.state == StreamState::Closed {
410            return Poll::Ready(Err(Error::ConnectionClosed));
411        }
412        Poll::Ready(Ok(()))
413    }
414
415    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<()> {
416        let this = self.get_mut();
417
418        if this.state == StreamState::Closed {
419            return Err(Error::ConnectionClosed);
420        }
421
422        // Track close frame sending
423        if item.is_close() {
424            this.state = StreamState::CloseSent;
425        }
426
427        // Encode message into write buffer
428        this.protocol
429            .encode_message(&item, this.write_buf.buffer_mut())?;
430        Ok(())
431    }
432
433    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
434        let this = self.as_mut().get_mut();
435
436        // Write all pending data
437        while this.write_buf.has_data() {
438            let slices = this.write_buf.get_write_slices();
439            if slices.is_empty() {
440                break;
441            }
442
443            match Pin::new(&mut this.inner).poll_write_vectored(cx, &slices) {
444                Poll::Ready(Ok(0)) => {
445                    return Poll::Ready(Err(Error::ConnectionClosed));
446                }
447                Poll::Ready(Ok(n)) => {
448                    this.write_buf.consume(n);
449                }
450                Poll::Ready(Err(e)) => {
451                    return Poll::Ready(Err(e.into()));
452                }
453                Poll::Pending => {
454                    return Poll::Pending;
455                }
456            }
457        }
458
459        // Flush underlying stream
460        match Pin::new(&mut self.as_mut().get_mut().inner).poll_flush(cx) {
461            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
462            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
463            Poll::Pending => Poll::Pending,
464        }
465    }
466
467    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
468        // Send close frame if not already sent
469        if self.state == StreamState::Open {
470            let close = Message::Close(Some(CloseReason::new(1000, "")));
471            if let Err(e) = self.as_mut().start_send(close) {
472                return Poll::Ready(Err(e));
473            }
474        }
475
476        // Flush pending data
477        match self.as_mut().poll_flush(cx) {
478            Poll::Ready(Ok(())) => {}
479            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
480            Poll::Pending => return Poll::Pending,
481        }
482
483        // Shutdown the underlying stream
484        match Pin::new(&mut self.as_mut().get_mut().inner).poll_shutdown(cx) {
485            Poll::Ready(Ok(())) => {
486                self.as_mut().get_mut().state = StreamState::Closed;
487                Poll::Ready(Ok(()))
488            }
489            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
490            Poll::Pending => Poll::Pending,
491        }
492    }
493}
494
/// Builder for WebSocket streams
///
/// Collects configuration (size limits, buffer sizes, role, backpressure
/// thresholds) before constructing a `WebSocketStream` via `build`.
pub struct WebSocketStreamBuilder {
    /// Protocol/buffer configuration forwarded to the stream
    config: Config,
    /// Client or server framing role (default: `Role::Server`)
    role: Role,
    /// Backpressure high water mark in bytes (default 64KB)
    high_water_mark: usize,
    /// Backpressure low water mark in bytes (default 16KB)
    low_water_mark: usize,
}
502
503impl WebSocketStreamBuilder {
504    /// Create a new builder with default configuration
505    pub fn new() -> Self {
506        Self {
507            config: Config::default(),
508            role: Role::Server,
509            high_water_mark: DEFAULT_HIGH_WATER_MARK,
510            low_water_mark: DEFAULT_LOW_WATER_MARK,
511        }
512    }
513
514    /// Set the endpoint role
515    pub fn role(mut self, role: Role) -> Self {
516        self.role = role;
517        self
518    }
519
520    /// Set the maximum message size
521    pub fn max_message_size(mut self, size: usize) -> Self {
522        self.config.max_message_size = size;
523        self
524    }
525
526    /// Set the maximum frame size
527    pub fn max_frame_size(mut self, size: usize) -> Self {
528        self.config.max_frame_size = size;
529        self
530    }
531
532    /// Set the write buffer size
533    pub fn write_buffer_size(mut self, size: usize) -> Self {
534        self.config.write_buffer_size = size;
535        self
536    }
537
538    /// Set the high water mark for backpressure
539    ///
540    /// When the write buffer exceeds this threshold, `is_backpressured()` returns `true`.
541    /// Default is 64KB.
542    pub fn high_water_mark(mut self, size: usize) -> Self {
543        self.high_water_mark = size;
544        self
545    }
546
547    /// Set the low water mark for backpressure
548    ///
549    /// When the write buffer drops below this threshold, `is_write_buffer_low()` returns `true`.
550    /// Default is 16KB.
551    pub fn low_water_mark(mut self, size: usize) -> Self {
552        self.low_water_mark = size;
553        self
554    }
555
556    /// Build the WebSocket stream
557    pub fn build<S>(self, stream: S) -> WebSocketStream<S>
558    where
559        S: AsyncRead + AsyncWrite + Unpin,
560    {
561        let mut ws = WebSocketStream::from_raw(stream, self.role, self.config);
562        ws.high_water_mark = self.high_water_mark;
563        ws.low_water_mark = self.low_water_mark;
564        ws
565    }
566}
567
impl Default for WebSocketStreamBuilder {
    /// Equivalent to [`WebSocketStreamBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
573
574// ============================================================================
575// Split Stream Implementation - LOCK-FREE EDITION 🚀
576// ============================================================================
577//
578// This implementation uses tokio::io::split() for true concurrent I/O:
579// - ✅ Zero mutex contention
580// - ✅ Reader and writer operate 100% independently
581// - ✅ Native OS-level efficiency
582// - ✅ No shared lock on the underlying transport
583//
584// Control frames (Ping/Pong/Close) are coordinated via an mpsc channel
585// from the reader to the writer, allowing the reader to request responses
586// without blocking.
587
588use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
589use tokio::sync::mpsc;
590
/// Control frame requests sent from reader to writer
///
/// The reader half cannot write to the transport, so protocol-mandated
/// responses (Pong for a received Ping, Close echo for a received Close)
/// are forwarded to the writer over an mpsc channel.
#[derive(Debug, Clone)]
enum ControlRequest {
    /// Send a Pong in response to a Ping
    Pong(bytes::Bytes),
    /// Send a Close response
    CloseResponse,
}
599
/// The read half of a split WebSocket stream
///
/// Created by calling `split()` on a `WebSocketStream`.
/// This half owns the read side of the transport and can run in a separate
/// task from the write half.
pub struct SplitReader<S> {
    /// Read half of the underlying stream (coordinated with the write half
    /// by tokio::io::split's internal BiLock)
    reader: ReadHalf<S>,
    /// Protocol for decoding
    protocol: Protocol,
    /// Read buffer
    read_buf: BytesMut,
    /// Pending messages from last decode; `pending_index` is the cursor
    pending_messages: Vec<Message>,
    pending_index: usize,
    /// Channel to send control frame requests (Pong/Close echo) to writer
    control_tx: mpsc::UnboundedSender<ControlRequest>,
    /// Connection state
    closed: bool,
}
620
/// The write half of a split WebSocket stream
///
/// Created by calling `split()` on a `WebSocketStream`.
/// This half owns the write side of the transport and can run in a separate
/// task from the read half.
pub struct SplitWriter<S> {
    /// Write half of the underlying stream (coordinated with the read half
    /// by tokio::io::split's internal BiLock)
    writer: WriteHalf<S>,
    /// Protocol for encoding
    protocol: Protocol,
    /// Scratch buffer, cleared and re-encoded on every send
    write_buf: BytesMut,
    /// Channel to receive control frame requests from reader
    control_rx: mpsc::UnboundedReceiver<ControlRequest>,
    /// Connection state
    closed: bool,
}
638
impl<S> WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Split the WebSocket stream into separate read and write halves
    ///
    /// The transport is split with `tokio::io::split`, so the two halves can
    /// be driven from separate tasks. Control frames observed by the reader
    /// (Ping / Close) are forwarded to the writer over an unbounded channel,
    /// so the reader never needs write access.
    ///
    /// NOTE(review): any bytes still buffered in `write_buf` (e.g. a queued
    /// Pong or Close response) are discarded by this call — flush before
    /// splitting.
    ///
    /// NOTE(review): the reader receives a *fresh* `Protocol` while the
    /// original goes to the writer. If `Protocol` carries mid-message decode
    /// state (a partially received fragmented message), that state is lost —
    /// confirm `Protocol::process` is stateless across calls, or only split
    /// on message boundaries.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let (mut reader, mut writer) = ws.split();
    ///
    /// // Read in one task
    /// tokio::spawn(async move {
    ///     while let Some(msg) = reader.next().await {
    ///         println!("Got: {:?}", msg);
    ///     }
    /// });
    ///
    /// // Write in another
    /// writer.send(Message::Text("Hello".into())).await?;
    /// ```
    pub fn split(self) -> (SplitReader<S>, SplitWriter<S>) {
        // Split the underlying transport (tokio coordinates the two halves
        // with an internal BiLock)
        let (reader, writer) = tokio::io::split(self.inner);

        // Create channel for control frame coordination
        let (control_tx, control_rx) = mpsc::unbounded_channel();

        // Build a fresh protocol for the reader; the writer keeps the
        // original (see decode-state caveat above).
        let reader_protocol = Protocol::new(
            self.protocol.role,
            self.config.max_frame_size,
            self.config.max_message_size,
        );
        let writer_protocol = self.protocol;

        (
            SplitReader {
                reader,
                protocol: reader_protocol,
                read_buf: self.read_buf,
                pending_messages: self.pending_messages,
                pending_index: self.pending_index,
                control_tx,
                closed: self.state == StreamState::Closed,
            },
            SplitWriter {
                writer,
                protocol: writer_protocol,
                write_buf: BytesMut::with_capacity(1024),
                control_rx,
                closed: self.state == StreamState::Closed,
            },
        )
    }
}
699
impl<S> SplitReader<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Receive the next message
    ///
    /// Returns `None` when the connection is closed (EOF, or after a Close
    /// frame has already been yielded).
    ///
    /// Unlike `WebSocketStream`'s `Stream` impl, Ping and Pong frames are
    /// handled internally and never yielded to the caller: a Ping queues a
    /// Pong request to the writer via the control channel and is skipped,
    /// and Pongs are silently dropped. NOTE(review): confirm this behavioral
    /// difference between split and unified streams is intended.
    pub async fn next(&mut self) -> Option<Result<Message>> {
        loop {
            // Check for connection closed
            if self.closed {
                return None;
            }

            // Return any pending messages first
            if self.pending_index < self.pending_messages.len() {
                let msg = self.pending_messages[self.pending_index].clone();
                self.pending_index += 1;

                if self.pending_index >= self.pending_messages.len() {
                    self.pending_messages.clear();
                    self.pending_index = 0;
                }

                // Handle control frames - send requests to writer via channel
                match &msg {
                    Message::Ping(data) => {
                        // Request writer to send pong (channel send never blocks;
                        // errors are ignored if the writer is gone)
                        let _ = self.control_tx.send(ControlRequest::Pong(data.clone()));
                        // Continue to next message (user doesn't see Ping)
                        continue;
                    }
                    Message::Close(reason) => {
                        // NOTE(review): this condition is always true here —
                        // `next()` returned early above when `closed` was set.
                        if !self.closed {
                            // Request writer to send close response
                            let _ = self.control_tx.send(ControlRequest::CloseResponse);
                            self.closed = true;
                        }
                        return Some(Ok(Message::Close(reason.clone())));
                    }
                    Message::Pong(_) => {
                        // User doesn't typically need to see Pong
                        continue;
                    }
                    _ => {}
                }

                return Some(Ok(msg));
            }

            // Need to read more data; ensure at least 4KB of spare capacity
            if self.read_buf.capacity() - self.read_buf.len() < 4096 {
                self.read_buf.reserve(8192);
            }

            match self.reader.read_buf(&mut self.read_buf).await {
                Ok(0) => {
                    // EOF - connection closed
                    self.closed = true;
                    return None;
                }
                Ok(_n) => {
                    // Process the new data
                    match self.protocol.process(&mut self.read_buf) {
                        Ok(messages) => {
                            if !messages.is_empty() {
                                self.pending_messages = messages;
                                self.pending_index = 0;
                            }
                        }
                        Err(e) => return Some(Err(e)),
                    }
                    // Continue loop to check for messages
                }
                Err(e) => {
                    return Some(Err(e.into()));
                }
            }
        }
    }

    /// Check if the connection is closed
    pub fn is_closed(&self) -> bool {
        self.closed
    }
}
788
789impl<S> SplitWriter<S>
790where
791    S: AsyncRead + AsyncWrite + Unpin,
792{
793    /// Send a message
794    ///
795    /// This method NEVER blocks the reader - true concurrent I/O!
796    pub async fn send(&mut self, msg: Message) -> Result<()> {
797        if self.closed {
798            return Err(Error::ConnectionClosed);
799        }
800
801        // Process any pending control frame requests from reader first
802        self.process_control_requests().await?;
803
804        if msg.is_close() {
805            self.closed = true;
806        }
807
808        // Encode message - NO LOCK HERE!
809        self.write_buf.clear();
810        self.protocol.encode_message(&msg, &mut self.write_buf)?;
811
812        // Write to the underlying stream - NO LOCK HERE!
813        self.writer.write_all(&self.write_buf).await?;
814        self.writer.flush().await?;
815        Ok(())
816    }
817
818    /// Process control frame requests from the reader
819    async fn process_control_requests(&mut self) -> Result<()> {
820        // Drain all pending control requests
821        while let Ok(req) = self.control_rx.try_recv() {
822            self.write_buf.clear();
823
824            match req {
825                ControlRequest::Pong(data) => {
826                    self.protocol.encode_pong(&data, &mut self.write_buf);
827                }
828                ControlRequest::CloseResponse => {
829                    self.protocol.encode_close_response(&mut self.write_buf);
830                    self.closed = true;
831                }
832            }
833
834            if !self.write_buf.is_empty() {
835                self.writer.write_all(&self.write_buf).await?;
836            }
837        }
838
839        Ok(())
840    }
841
842    /// Send a text message
843    pub async fn send_text(&mut self, text: impl Into<String>) -> Result<()> {
844        self.send(Message::text(text)).await
845    }
846
847    /// Send a binary message
848    pub async fn send_binary(&mut self, data: bytes::Bytes) -> Result<()> {
849        self.send(Message::Binary(data)).await
850    }
851
852    /// Send a close frame
853    pub async fn close(&mut self, code: u16, reason: &str) -> Result<()> {
854        self.send(Message::Close(Some(CloseReason::new(code, reason))))
855            .await
856    }
857
858    /// Check if the connection is closed
859    pub fn is_closed(&self) -> bool {
860        self.closed
861    }
862
863    /// Flush any pending control responses
864    pub async fn flush(&mut self) -> Result<()> {
865        self.process_control_requests().await?;
866        self.writer.flush().await.map_err(Into::into)
867    }
868}
869
870// ============================================================================
871// Compressed WebSocket Stream (permessage-deflate)
872// ============================================================================
873
#[cfg(feature = "permessage-deflate")]
pin_project! {
    /// A WebSocket stream with permessage-deflate compression (RFC 7692)
    ///
    /// This type mirrors `WebSocketStream` but uses `CompressedProtocol` for
    /// automatic compression/decompression of messages.
    pub struct CompressedWebSocketStream<S> {
        // Underlying async transport; pinned for poll-based I/O
        #[pin]
        inner: S,
        // Compressing/decompressing frame codec
        protocol: crate::protocol::CompressedProtocol,
        // Raw received bytes awaiting decode
        read_buf: BytesMut,
        // Encoded frames awaiting transmission
        write_buf: CorkBuffer,
        // Connection lifecycle state (shared with the uncompressed stream)
        state: StreamState,
        config: Config,
        // Decoded messages handed out via an index cursor
        pending_messages: Vec<Message>,
        pending_index: usize,
        // Backpressure thresholds in bytes
        high_water_mark: usize,
        low_water_mark: usize,
    }
}
894
#[cfg(feature = "permessage-deflate")]
impl<S> CompressedWebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Shared constructor body for [`Self::server`] and [`Self::client`].
    fn with_protocol(
        inner: S,
        config: Config,
        protocol: crate::protocol::CompressedProtocol,
    ) -> Self {
        Self {
            inner,
            protocol,
            read_buf: BytesMut::with_capacity(crate::RECV_BUFFER_SIZE),
            write_buf: CorkBuffer::with_capacity(config.write_buffer_size),
            state: StreamState::Open,
            config,
            pending_messages: Vec::new(),
            pending_index: 0,
            high_water_mark: DEFAULT_HIGH_WATER_MARK,
            low_water_mark: DEFAULT_LOW_WATER_MARK,
        }
    }

    /// Create a new compressed WebSocket stream for server role
    pub fn server(inner: S, config: Config, deflate_config: crate::deflate::DeflateConfig) -> Self {
        let protocol = crate::protocol::CompressedProtocol::server(
            config.max_frame_size,
            config.max_message_size,
            deflate_config,
        );
        Self::with_protocol(inner, config, protocol)
    }

    /// Create a new compressed WebSocket stream for client role
    pub fn client(inner: S, config: Config, deflate_config: crate::deflate::DeflateConfig) -> Self {
        let protocol = crate::protocol::CompressedProtocol::client(
            config.max_frame_size,
            config.max_message_size,
            deflate_config,
        );
        Self::with_protocol(inner, config, protocol)
    }

    /// Check if the connection is closed
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.state == StreamState::Closed || self.protocol.is_closed()
    }

    /// Check if backpressure should be applied
    ///
    /// Returns `true` once the write buffer exceeds the high water mark;
    /// producers should pause and flush until it drains.
    #[inline]
    pub fn is_backpressured(&self) -> bool {
        self.write_buf.pending_bytes() > self.high_water_mark
    }

    /// Get the current write buffer length in bytes
    #[inline]
    pub fn write_buffer_len(&self) -> usize {
        self.write_buf.pending_bytes()
    }

    /// Send a close frame and flush it to the peer
    ///
    /// No-op (returns `Ok(())`) if a close was already sent or the
    /// connection is closed.
    pub async fn close(&mut self, code: u16, reason: &str) -> Result<()> {
        if self.state != StreamState::Open {
            return Ok(());
        }

        let close = Message::Close(Some(CloseReason::new(code, reason)));
        self.protocol
            .encode_message(&close, self.write_buf.buffer_mut())?;
        self.state = StreamState::CloseSent;

        self.flush_write_buf().await?;
        Ok(())
    }

    /// Flush the write buffer to the underlying stream
    ///
    /// # Errors
    ///
    /// Returns `Error::ConnectionClosed` if the transport accepts zero bytes.
    async fn flush_write_buf(&mut self) -> Result<()> {
        use tokio::io::AsyncWriteExt;

        while self.write_buf.has_data() {
            let slices = self.write_buf.get_write_slices();
            if slices.is_empty() {
                break;
            }

            let n = self.inner.write_vectored(&slices).await?;
            if n == 0 {
                return Err(Error::ConnectionClosed);
            }
            self.write_buf.consume(n);
        }

        self.inner.flush().await?;
        Ok(())
    }

    /// Read more data from the underlying stream into `read_buf`
    ///
    /// Returns `Ok(0)` on EOF, otherwise the number of bytes appended.
    fn poll_read_more(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        let this = self.project();

        // Keep at least 4KB of headroom so short reads don't thrash.
        if this.read_buf.capacity() - this.read_buf.len() < 4096 {
            this.read_buf.reserve(8192);
        }

        // Read straight into the spare (possibly uninitialized) capacity.
        // `ReadBuf::uninit` tracks initialization itself, which avoids the
        // undefined behavior of materializing a `&mut [u8]` over
        // uninitialized bytes (the previous `set_len(capacity)` approach).
        let mut read_buf = ReadBuf::uninit(this.read_buf.spare_capacity_mut());

        match this.inner.poll_read(cx, &mut read_buf) {
            Poll::Ready(Ok(())) => {
                let n = read_buf.filled().len();
                let len = this.read_buf.len();
                // SAFETY: `poll_read` guarantees the first `n` bytes of the
                // spare capacity have been initialized.
                unsafe {
                    this.read_buf.set_len(len + n);
                }
                Poll::Ready(Ok(n))
            }
            // Nothing was appended, so no length rollback is needed.
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Decode any complete messages sitting in the read buffer
    ///
    /// Decoded messages are stored in `pending_messages` for
    /// `next_pending_message` to hand out one at a time.
    fn process_read_buf(&mut self) -> Result<()> {
        if self.read_buf.is_empty() {
            return Ok(());
        }

        let messages = self.protocol.process(&mut self.read_buf)?;

        if !messages.is_empty() {
            self.pending_messages = messages;
            self.pending_index = 0;
        }

        Ok(())
    }

    /// Get the next pending message, resetting the queue once drained
    fn next_pending_message(&mut self) -> Option<Message> {
        if self.pending_index < self.pending_messages.len() {
            // NOTE(review): clone is presumably cheap — binary payloads are
            // `bytes::Bytes` (see `send_binary`) — confirm for text frames.
            let msg = self.pending_messages[self.pending_index].clone();
            self.pending_index += 1;

            if self.pending_index >= self.pending_messages.len() {
                self.pending_messages.clear();
                self.pending_index = 0;
            }

            Some(msg)
        } else {
            None
        }
    }
}
1075
#[cfg(feature = "permessage-deflate")]
impl<S> Stream for CompressedWebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Result<Message>;

    // Poll loop: drain already-decoded messages first, then read + decode
    // more from the transport. Ping and Close frames get their responses
    // encoded into the write buffer here, but NOTE(review): those responses
    // are only transmitted on a later flush/send — on an otherwise idle
    // connection they can sit buffered indefinitely. Confirm callers flush
    // after reading.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            if self.state == StreamState::Closed {
                return Poll::Ready(None);
            }

            // Hand out buffered messages before touching the transport.
            if let Some(msg) = self.as_mut().get_mut().next_pending_message() {
                match &msg {
                    Message::Ping(data) => {
                        // Queue a pong reply; the ping is still surfaced to
                        // the caller below.
                        let this = self.as_mut().get_mut();
                        this.protocol.encode_pong(data, this.write_buf.buffer_mut());
                    }
                    Message::Close(reason) => {
                        let this = self.as_mut().get_mut();
                        if this.state == StreamState::Open {
                            // Echo the close handshake and end the stream.
                            this.protocol
                                .encode_close_response(this.write_buf.buffer_mut());
                            this.state = StreamState::Closed;
                        }
                        return Poll::Ready(Some(Ok(Message::Close(reason.clone()))));
                    }
                    _ => {}
                }

                return Poll::Ready(Some(Ok(msg)));
            }

            match self.as_mut().poll_read_more(cx) {
                Poll::Ready(Ok(0)) => {
                    // EOF: the peer went away without a close frame.
                    self.as_mut().get_mut().state = StreamState::Closed;
                    return Poll::Ready(None);
                }
                Poll::Ready(Ok(_n)) => match self.as_mut().get_mut().process_read_buf() {
                    // New messages may now be pending; loop to hand them out.
                    Ok(()) => continue,
                    Err(e) => return Poll::Ready(Some(Err(e))),
                },
                Poll::Ready(Err(e)) => {
                    return Poll::Ready(Some(Err(e.into())));
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}
1129
#[cfg(feature = "permessage-deflate")]
impl<S> Sink<Message> for CompressedWebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = Error;

    /// Ready to accept a message unless closed or backpressured.
    ///
    /// When the write buffer is above the high water mark this drains it
    /// (returning `Pending` until the transport accepts the data), so that
    /// `SinkExt::send` naturally applies the backpressure the stream's
    /// documentation advertises. Previously the water marks were ignored
    /// here and the buffer could grow without bound.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        if self.state == StreamState::Closed {
            return Poll::Ready(Err(Error::ConnectionClosed));
        }
        if self.write_buf.pending_bytes() > self.high_water_mark {
            return self.poll_flush(cx);
        }
        Poll::Ready(Ok(()))
    }

    /// Encode `item` into the write buffer (no I/O happens until flush).
    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<()> {
        let this = self.get_mut();

        if this.state == StreamState::Closed {
            return Err(Error::ConnectionClosed);
        }

        if item.is_close() {
            this.state = StreamState::CloseSent;
        }

        this.protocol
            .encode_message(&item, this.write_buf.buffer_mut())?;
        Ok(())
    }

    /// Write all buffered frames to the transport, then flush it.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let this = self.as_mut().get_mut();

        while this.write_buf.has_data() {
            let slices = this.write_buf.get_write_slices();
            if slices.is_empty() {
                break;
            }

            match Pin::new(&mut this.inner).poll_write_vectored(cx, &slices) {
                Poll::Ready(Ok(0)) => {
                    // Transport accepted nothing: treat as a dead connection.
                    return Poll::Ready(Err(Error::ConnectionClosed));
                }
                Poll::Ready(Ok(n)) => {
                    this.write_buf.consume(n);
                }
                Poll::Ready(Err(e)) => {
                    return Poll::Ready(Err(e.into()));
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }

        match Pin::new(&mut self.as_mut().get_mut().inner).poll_flush(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Send a normal-closure (1000) close frame if still open, flush all
    /// buffered data, then shut down the transport.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        if self.state == StreamState::Open {
            let close = Message::Close(Some(CloseReason::new(1000, "")));
            if let Err(e) = self.as_mut().start_send(close) {
                return Poll::Ready(Err(e));
            }
        }

        match self.as_mut().poll_flush(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Pending => return Poll::Pending,
        }

        match Pin::new(&mut self.as_mut().get_mut().inner).poll_shutdown(cx) {
            Poll::Ready(Ok(())) => {
                self.as_mut().get_mut().state = StreamState::Closed;
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
            Poll::Pending => Poll::Pending,
        }
    }
}
1216
1217// ============================================================================
1218// Compressed Split Reader/Writer (permessage-deflate)
1219// ============================================================================
1220
/// The read half of a split compressed WebSocket stream
///
/// Created by calling `split()` on a `CompressedWebSocketStream`.
/// This half owns the read side of the TCP stream and can operate
/// completely independently from the write half.
#[cfg(feature = "permessage-deflate")]
pub struct CompressedSplitReader<S> {
    /// Read half of the underlying stream
    reader: ReadHalf<S>,
    /// Protocol for decoding with decompression
    protocol: crate::protocol::CompressedReaderProtocol,
    /// Read buffer
    read_buf: BytesMut,
    /// Pending messages from last decode
    pending_messages: Vec<Message>,
    /// Index of the next message to hand out from `pending_messages`
    pending_index: usize,
    /// Channel to send control frame requests (pong / close response) to writer
    control_tx: mpsc::UnboundedSender<ControlRequest>,
    /// Connection state (set on EOF or after receiving a close frame)
    closed: bool,
}
1242
/// The write half of a split compressed WebSocket stream
///
/// Created by calling `split()` on a `CompressedWebSocketStream`.
/// This half owns the write side of the TCP stream and can operate
/// completely independently from the read half.
#[cfg(feature = "permessage-deflate")]
pub struct CompressedSplitWriter<S> {
    /// Write half of the underlying stream
    writer: WriteHalf<S>,
    /// Protocol for encoding with compression
    protocol: crate::protocol::CompressedWriterProtocol,
    /// Write buffer for encoding (cleared and reused per frame)
    write_buf: BytesMut,
    /// Channel to receive control frame requests from reader
    control_rx: mpsc::UnboundedReceiver<ControlRequest>,
    /// Connection state (set after sending or acknowledging a close)
    closed: bool,
}
1261
#[cfg(feature = "permessage-deflate")]
impl<S> CompressedWebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Split the compressed WebSocket stream into separate read and write halves
    ///
    /// This allows TRUE concurrent reading and writing from different tasks
    /// with ZERO lock contention. The underlying TCP stream is split at the
    /// OS level for maximum performance.
    ///
    /// Both halves maintain compression/decompression state independently:
    /// - Reader has the decoder for decompressing incoming messages
    /// - Writer has the encoder for compressing outgoing messages
    ///
    /// # Example
    ///
    /// ```ignore
    /// let (mut reader, mut writer) = compressed_ws.split();
    ///
    /// // Read in one task - NEVER blocks writer
    /// tokio::spawn(async move {
    ///     while let Some(msg) = reader.next().await {
    ///         println!("Got: {:?}", msg);
    ///     }
    /// });
    ///
    /// // Write in another - NEVER blocks reader
    /// writer.send(Message::Text("Hello".into())).await?;
    /// ```
    pub fn split(self) -> (CompressedSplitReader<S>, CompressedSplitWriter<S>) {
        // Split the underlying transport at the OS level
        let (reader, writer) = tokio::io::split(self.inner);

        // Create channel for control frame coordination (reader asks the
        // writer to emit pongs / close responses without sharing the socket)
        let (control_tx, control_rx) = mpsc::unbounded_channel();

        // Split the protocol into reader and writer halves
        let (reader_protocol, writer_protocol) = self
            .protocol
            .split(self.config.max_frame_size, self.config.max_message_size);

        (
            CompressedSplitReader {
                reader,
                protocol: reader_protocol,
                // Carry over any partially-read data and undelivered messages
                read_buf: self.read_buf,
                pending_messages: self.pending_messages,
                pending_index: self.pending_index,
                control_tx,
                closed: self.state == StreamState::Closed,
            },
            CompressedSplitWriter {
                writer,
                protocol: writer_protocol,
                // Size the encode buffer from the configuration, consistent
                // with the unified stream (was a hard-coded 1024 bytes).
                write_buf: BytesMut::with_capacity(self.config.write_buffer_size),
                control_rx,
                closed: self.state == StreamState::Closed,
            },
        )
    }
}
1324
#[cfg(feature = "permessage-deflate")]
impl<S> CompressedSplitReader<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Pop the next buffered message, resetting the queue once drained.
    fn pop_pending(&mut self) -> Option<Message> {
        if self.pending_index >= self.pending_messages.len() {
            return None;
        }

        let msg = self.pending_messages[self.pending_index].clone();
        self.pending_index += 1;

        if self.pending_index >= self.pending_messages.len() {
            self.pending_messages.clear();
            self.pending_index = 0;
        }

        Some(msg)
    }

    /// Receive the next message
    ///
    /// Returns `None` when the connection is closed. Control frames are
    /// forwarded to the write half over the control channel, so this never
    /// blocks the writer — true concurrent I/O.
    pub async fn next(&mut self) -> Option<Result<Message>> {
        while !self.closed {
            // Serve already-decoded messages before touching the socket.
            if let Some(msg) = self.pop_pending() {
                match &msg {
                    Message::Ping(payload) => {
                        // Ask the writer half to reply with a pong
                        // (non-blocking); the user never sees the ping.
                        let _ = self.control_tx.send(ControlRequest::Pong(payload.clone()));
                        continue;
                    }
                    Message::Pong(_) => {
                        // Protocol noise from the user's perspective.
                        continue;
                    }
                    Message::Close(reason) => {
                        // Ask the writer half to echo the close handshake.
                        let _ = self.control_tx.send(ControlRequest::CloseResponse);
                        self.closed = true;
                        return Some(Ok(Message::Close(reason.clone())));
                    }
                    _ => {}
                }

                return Some(Ok(msg));
            }

            // Nothing pending: top up the read buffer. Keep some headroom
            // so short reads don't force constant reallocation.
            if self.read_buf.capacity() - self.read_buf.len() < 4096 {
                self.read_buf.reserve(8192);
            }

            match self.reader.read_buf(&mut self.read_buf).await {
                Ok(0) => {
                    // EOF - transport closed by the peer.
                    self.closed = true;
                    return None;
                }
                Ok(_) => match self.protocol.process(&mut self.read_buf) {
                    Ok(msgs) if !msgs.is_empty() => {
                        self.pending_messages = msgs;
                        self.pending_index = 0;
                    }
                    Ok(_) => {}
                    Err(e) => return Some(Err(e)),
                },
                Err(e) => return Some(Err(e.into())),
            }
        }

        None
    }

    /// Check if the connection is closed
    pub fn is_closed(&self) -> bool {
        self.closed
    }
}
1414
#[cfg(feature = "permessage-deflate")]
impl<S> CompressedSplitWriter<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Send a message
    ///
    /// Control-frame requests queued by the read half (pongs, close
    /// responses) are written first, then the message itself. This method
    /// NEVER blocks the reader - true concurrent I/O!
    ///
    /// # Errors
    ///
    /// Returns `Error::ConnectionClosed` if a close was already sent.
    pub async fn send(&mut self, msg: Message) -> Result<()> {
        if self.closed {
            return Err(Error::ConnectionClosed);
        }

        // Honor pending pong/close requests from the reader before user data.
        self.process_control_requests().await?;

        // Mark closed up front so no further frames can follow a close
        // frame, even if the write below fails partway.
        if msg.is_close() {
            self.closed = true;
        }

        // Encode with compression into the reusable buffer - NO LOCK HERE!
        self.write_buf.clear();
        self.protocol.encode_message(&msg, &mut self.write_buf)?;

        // Write to the underlying stream - NO LOCK HERE!
        self.writer.write_all(&self.write_buf).await?;
        self.writer.flush().await?;
        Ok(())
    }

    /// Process control frame requests from the reader
    ///
    /// Drains the channel without awaiting on it; only the resulting writes
    /// can suspend. Flushing is left to the caller.
    async fn process_control_requests(&mut self) -> Result<()> {
        while let Ok(req) = self.control_rx.try_recv() {
            self.write_buf.clear();

            match req {
                ControlRequest::Pong(data) => {
                    self.protocol.encode_pong(&data, &mut self.write_buf);
                }
                ControlRequest::CloseResponse => {
                    self.protocol.encode_close_response(&mut self.write_buf);
                    self.closed = true;
                }
            }

            if !self.write_buf.is_empty() {
                self.writer.write_all(&self.write_buf).await?;
            }
        }

        Ok(())
    }

    /// Send a text message
    pub async fn send_text(&mut self, text: impl Into<String>) -> Result<()> {
        self.send(Message::text(text)).await
    }

    /// Send a binary message
    pub async fn send_binary(&mut self, data: bytes::Bytes) -> Result<()> {
        self.send(Message::Binary(data)).await
    }

    /// Send a close frame
    ///
    /// Idempotent: returns `Ok(())` if the connection is already closed,
    /// matching `CompressedWebSocketStream::close` (previously a second
    /// call surfaced `Error::ConnectionClosed`).
    pub async fn close(&mut self, code: u16, reason: &str) -> Result<()> {
        if self.closed {
            return Ok(());
        }
        self.send(Message::Close(Some(CloseReason::new(code, reason))))
            .await
    }

    /// Check if the connection is closed
    pub fn is_closed(&self) -> bool {
        self.closed
    }

    /// Flush any pending control responses and buffered data
    pub async fn flush(&mut self) -> Result<()> {
        self.process_control_requests().await?;
        self.writer.flush().await.map_err(Into::into)
    }
}
1496
#[cfg(test)]
mod tests {
    use super::*;

    // Tests would require a mock async transport
    // For now, we just verify the types compile correctly

    // Smoke test: the builder's fluent API compiles and the calls chain.
    #[test]
    fn test_builder() {
        let _builder = WebSocketStreamBuilder::new()
            .role(Role::Server)
            .max_message_size(1024 * 1024)
            .max_frame_size(64 * 1024);
    }
}
1511}