fastwebsockets_stream/
stream.rs

1use bytes::Bytes;
2use bytes::BytesMut;
3use fastwebsockets::{Frame, OpCode, Payload, WebSocket, WebSocketError};
4use futures::FutureExt;
5use futures::future::BoxFuture;
6use std::fmt::Debug;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12/// Future output type for operations that temporarily own the websocket.
13///
14/// The future returns either an owned `WebSocket<S>` back together with a
15/// result value `T`, or a `WebSocketError` if the operation failed.
16type FutureResult<S, T> = Result<(WebSocket<S>, T), WebSocketError>;
17
18/// Internal owned frame representation.
19///
20/// When we read a frame from `WebSocket::read_frame()` it borrows internal
21/// buffers. To be able to return both the websocket and the payload across an
22/// `await` point we copy the payload into an owned `Bytes` and store the opcode.
23struct PayloadFrame {
24    /// Opcode of the frame (Text/Binary/Close/etc).
25    opcode: OpCode,
26    /// Owned payload bytes of the frame.
27    payload: Bytes,
28}
29
30/// Read state machine for `WebSocketStream`.
31///
32/// We encode whether we are idle or currently running an owned future that has
33/// taken ownership of the underlying `WebSocket` to perform an asynchronous
34/// read operation. The owned future returns the websocket together with the
35/// read `PayloadFrame`.
36enum ReadState<S> {
37    /// No read in progress.
38    Idle,
39    /// A boxed future that owns the websocket and will produce a `PayloadFrame`
40    /// (and the websocket) when complete.
41    Reading(BoxFuture<'static, FutureResult<S, PayloadFrame>>),
42}
43
44/// Write state machine for `WebSocketStream`.
45///
46/// Similar to `ReadState`, but represents a write operation that owns the
47/// websocket until it completes.
48enum WriteState<S> {
49    /// No write in progress.
50    Idle,
51    /// A boxed future that owns the websocket and will complete the write,
52    /// returning the websocket.
53    Writing(BoxFuture<'static, FutureResult<S, ()>>),
54}
55
/// Kind of application data carried by a `WebSocketStream`.
///
/// The chosen variant decides which data opcode outgoing frames use and which
/// data opcode incoming frames are required to have.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PayloadType {
    /// Data travels in Binary frames.
    Binary,
    /// Data travels in (UTF-8) Text frames.
    Text,
}
68
69impl From<PayloadType> for OpCode {
70    fn from(value: PayloadType) -> Self {
71        match value {
72            PayloadType::Binary => OpCode::Binary,
73            PayloadType::Text => OpCode::Text,
74        }
75    }
76}
77
78/// Map a `WebSocketError` into an `io::Error` for compatibility with the
79/// `AsyncRead`/`AsyncWrite` trait surfaces.
80fn make_io_err(e: WebSocketError) -> io::Error {
81    io::Error::other(format!("Websocket error: {}", e))
82}
83
84/// Helper: create a boxed future that owns the websocket and reads a frame.
85///
86/// The returned future will call `websocket.read_frame().await`, copy the
87/// payload into an owned `Bytes`, and return `(websocket, PayloadFrame)` on
88/// success or `WebSocketError` on failure.
89///
90/// This helper is private because it requires taking ownership of the
91/// `WebSocket` (which is stored as `Option` inside `WebSocketStream`) and
92/// boxing the resulting future so the `WebSocketStream` state machine can store
93/// it.
94fn read<S>(mut websocket: WebSocket<S>) -> BoxFuture<'static, FutureResult<S, PayloadFrame>>
95where
96    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
97{
98    async move {
99        // read_frame() returns Frame<'_> which borrows the websocket's buffers;
100        // we immediately copy the payload into an owned Bytes so the PayloadFrame
101        // can be returned with the websocket.
102        match websocket.read_frame().await {
103            Ok(frame) => {
104                let payload = match frame.payload {
105                    Payload::BorrowedMut(buf) => Bytes::from(buf.to_vec()),
106                    Payload::Borrowed(buf) => Bytes::from(buf.to_vec()),
107                    Payload::Owned(vec) => Bytes::from(vec),
108                    Payload::Bytes(bytes) => bytes.freeze(),
109                };
110
111                let owned = PayloadFrame {
112                    opcode: frame.opcode,
113                    payload,
114                };
115                Ok((websocket, owned))
116            }
117            Err(e) => Err(e),
118        }
119    }
120    .boxed()
121}
122
123/// Helper: create a boxed future that owns the websocket and writes the provided payload.
124///
125/// This helper constructs a single-frame message with the chosen `payload_type`
126/// (Text or Binary) and writes it with `websocket.write_frame(...)`. The future
127/// returns the websocket on success so ownership can be restored to the stream.
128fn write<S>(
129    mut websocket: WebSocket<S>,
130    payload: BytesMut,
131    payload_type: PayloadType,
132) -> BoxFuture<'static, FutureResult<S, ()>>
133where
134    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
135{
136    async move {
137        let frame = Frame::new(true, payload_type.into(), None, Payload::Bytes(payload));
138        match websocket.write_frame(frame).await {
139            Ok(()) => Ok((websocket, ())),
140            Err(e) => Err(e),
141        }
142    }
143    .boxed()
144}
145
146/// Helper: create a boxed future that owns the websocket and flushes it.
147///
148/// This issues a flush on the underlying `WebSocket` (which may flush any
149/// internal write buffers) and returns the websocket afterwards.
150fn flush<S>(mut websocket: WebSocket<S>) -> BoxFuture<'static, FutureResult<S, ()>>
151where
152    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
153{
154    async move {
155        match websocket.flush().await {
156            Ok(()) => Ok((websocket, ())),
157            Err(e) => Err(e),
158        }
159    }
160    .boxed()
161}
162
163/// Helper: create a boxed future that owns the websocket and sends a Close frame.
164///
165/// This writes a close frame and returns the websocket. Used by `poll_shutdown`.
166fn close<S>(mut websocket: WebSocket<S>) -> BoxFuture<'static, FutureResult<S, ()>>
167where
168    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
169{
170    async move {
171        let frame = Frame::close_raw(Vec::new().into());
172        match websocket.write_frame(frame).await {
173            Ok(()) => Ok((websocket, ())),
174            Err(e) => Err(e),
175        }
176    }
177    .boxed()
178}
179
180/// An `AsyncRead` / `AsyncWrite` adapter over a `fastwebsockets::WebSocket`.
181///
182/// `WebSocketStream<S>` wraps a `WebSocket<S>` and exposes a byte-stream view
183/// (implementing `tokio::io::AsyncRead` and `tokio::io::AsyncWrite`) so that
184/// websocket application payloads can be used with existing I/O and codec
185/// infrastructure such as `tokio_util::codec::Framed`.
186///
187/// ## Behavior
188///
189/// * Incoming WebSocket data frames (Text or Binary depending on the stream's
190///   `PayloadType`) are presented as a continuous byte stream. Each data frame's
191///   payload is returned in-order; if a read buffer provided by the caller is
192///   smaller than a frame payload, the remainder is buffered internally and
193///   served on subsequent reads.
194/// * Control frames (Ping/Pong) are handled by the underlying `WebSocket`
195///   (auto-pong) or ignored by this adapter. A `Close` frame marks EOF and
196///   subsequent reads return `Ok(())` with zero bytes (standard EOF semantics).
197/// * Writes produce single complete WebSocket data frames of the configured
198///   `PayloadType`. Each `poll_write` call sends one WebSocket data frame with
199///   the provided bytes as payload. The number of bytes reported as written is
200///   the length of `buf` supplied to `poll_write`.
201///
202/// ## Notes on threading and ownership
203///
204/// The adapter temporarily takes ownership of the inner `WebSocket` when it
205/// needs to perform an asynchronous read or write operation. To achieve this
206/// without requiring `WebSocket` itself to be `Sync`/`Send` across await points
207/// we spawn a boxed future that owns the websocket and returns it when the
208/// operation completes. This is implemented internally using `ReadState` and
209/// `WriteState`.
210///
211/// ## Example
212///
213/// ```rust
214/// use tokio::io::{AsyncReadExt, AsyncWriteExt};
215/// use tokio::net::TcpStream;
216/// use fastwebsockets::WebSocket;
217/// use fastwebsockets_stream::{WebSocketStream, PayloadType};
218///
219/// // Wrap the websocket and apply a line-based codec:
220/// async fn example<S>(_ws: WebSocket<S>)
221///     where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
222///     // This example is illustrative: constructing a real `WebSocket` requires
223///     // an underlying transport (e.g. a `TcpStream`) and the fastwebsockets
224///     // connection/handshake. Assume `ws` is a valid WebSocket<TcpStream>.
225///
226///     let ws: WebSocket<S> = unimplemented!();
227///     let mut ws_stream = WebSocketStream::new(ws, PayloadType::Binary);
228///
229///     // Write bytes -> sends a Binary frame
230///     let _n = ws_stream.write(b"hello").await;
231///
232///     // Read bytes
233///     let mut buf = vec![0_u8; 1024];
234///     let _ = ws_stream.read(&mut buf).await;
235///
236///     // Shutdown (sends Close)
237///     let _ = ws_stream.shutdown().await;
238/// }
239/// ```
240///
241/// Another common usage is to use `tokio_util::codec::Framed` to apply a codec
242/// on top of `WebSocketStream` (for example a length-delimited or line-based
243/// codec). Example:
244///
245/// ```rust
246/// use tokio_util::codec::{Framed, LinesCodec};
247/// use fastwebsockets::WebSocket;
248/// use fastwebsockets_stream::{WebSocketStream, PayloadType};
249///
250/// // Wrap the websocket and apply a line-based codec:
251/// async fn example<S>(_ws: WebSocket<S>)
252///     where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
253///     let ws: WebSocket<S> = unimplemented!();
254///     let stream = WebSocketStream::new(ws, PayloadType::Text);
255///     let mut framed = Framed::new(stream, LinesCodec::new());
256///
257///     // Now you can use framed.read() / framed.send() to work with String frames.
258/// }
259/// ```
260pub struct WebSocketStream<S> {
261    /// The inner websocket. Stored as `Option`
262    /// to allow temporarily taking ownership when starting an owned future
263    websocket: Option<WebSocket<S>>,
264
265    /// Buffer containing leftover bytes from the current
266    /// incoming message that didn't fit the last caller-provided read buffer
267    read_buf: BytesMut,
268
269    /// State machine for an in-progress read future that owns the websocket
270    read_state: ReadState<S>,
271
272    /// State machine for an in-progress write future that owns the websocket
273    write_state: WriteState<S>,
274
275    /// If `Some(n)` then a write is in progress and intends to report `n` bytes
276    /// written when the write future completes. We store the length separately
277    /// because the actual write future only stores the websocket and the
278    /// payload it sent
279    pending_write_len: Option<usize>,
280
281    /// Expected and emitted payload type (Text or Binary). Received frames with
282    /// a different data opcode are treated as errors
283    payload_type: PayloadType,
284
285    /// Set to `true` after a Close frame has been observed.
286    /// When `closed` is true, subsequent reads return EOF
287    closed: bool,
288}
289
290impl<S> WebSocketStream<S>
291where
292    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
293{
294    /// Create a new `WebSocketStream` wrapping the provided `WebSocket`.
295    ///
296    /// This will enable automatic Pong replies and automatic Close handling on
297    /// the wrapped `WebSocket` and initialize internal buffers and state.
298    ///
299    /// `payload_type` selects whether this stream should read/write Text or
300    /// Binary data. If the peer sends data frames with an opcode that does not
301    /// match `payload_type`, reads will return an error.
302    pub fn new(mut websocket: WebSocket<S>, payload_type: PayloadType) -> Self {
303        // Set auto pong and close
304        websocket.set_auto_pong(true);
305        websocket.set_auto_close(true);
306
307        Self {
308            websocket: Some(websocket),
309            read_buf: BytesMut::with_capacity(8 * 1024),
310            read_state: ReadState::Idle,
311            write_state: WriteState::Idle,
312            pending_write_len: None,
313            payload_type,
314            closed: false,
315        }
316    }
317
318    /// Consume the adapter and attempt to return the inner `WebSocket`.
319    ///
320    /// This returns `Some(WebSocket<S>)` if the websocket currently resides in
321    /// the adapter. If there is an outstanding future that currently owns the
322    /// websocket (i.e. a read or write in progress) this method will return
323    /// `None` because the adapter cannot recover the websocket until that
324    /// future completes.
325    pub fn into_inner(mut self) -> Option<WebSocket<S>> {
326        // If there is an outstanding future that currently owns the websocket,
327        // we cannot recover it here. We only return the inner websocket if it
328        // currently resides in `self.ws`.
329        self.websocket.take()
330    }
331
332    /// Returns `true` if we've observed a Close frame from the peer and the
333    /// stream reached EOF.
334    pub fn is_closed(&self) -> bool {
335        self.closed
336    }
337}
338
339impl<S> AsyncRead for WebSocketStream<S>
340where
341    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
342{
343    fn poll_read(
344        mut self: Pin<&mut Self>,
345        cx: &mut Context<'_>,
346        buf: &mut ReadBuf<'_>,
347    ) -> Poll<io::Result<()>> {
348        // If there are buffered bytes from previous frame, satisfy the read.
349        if !self.read_buf.is_empty() {
350            let to_copy = std::cmp::min(self.read_buf.len(), buf.remaining());
351            buf.put_slice(&self.read_buf.split_to(to_copy));
352            return Poll::Ready(Ok(()));
353        }
354
355        // If we've previously observed Close/EOF, report EOF by returning Ok(())
356        if self.closed {
357            return Poll::Ready(Ok(()));
358        }
359
360        loop {
361            // Match current read future state
362            match &mut self.read_state {
363                ReadState::Idle => {
364                    // Start a new read future by taking the websocket
365                    let websocket = match self.websocket.take() {
366                        Some(websocket) => websocket,
367                        None => {
368                            return Poll::Ready(Err(io::Error::other("Websocket not available")));
369                        }
370                    };
371                    let future = read(websocket);
372                    self.read_state = ReadState::Reading(future);
373                }
374                ReadState::Reading(fut) => {
375                    // Poll the future. If Pending, return Pending. If Ready,
376                    // reinstate websocket and handle frame.
377                    let mut future_pin = unsafe { Pin::new_unchecked(fut) };
378                    match future_pin.as_mut().poll(cx) {
379                        Poll::Pending => return Poll::Pending,
380                        Poll::Ready(res) => {
381                            // Transition back to Idle
382                            self.read_state = ReadState::Idle;
383                            match res {
384                                Ok((websocket, frame)) => {
385                                    // Put websocket back
386                                    self.websocket = Some(websocket);
387
388                                    match frame.opcode {
389                                        OpCode::Binary | OpCode::Text => {
390                                            // If frame payload type isn't match the desired type,
391                                            // return error
392                                            if frame.opcode != self.payload_type.into() {
393                                                return Poll::Ready(Err(io::Error::other(
394                                                    "The received data type is different \
395                                                    from the stream data type",
396                                                )));
397                                            }
398
399                                            // Check frame payload
400                                            let payload = frame.payload;
401                                            if payload.is_empty() {
402                                                // Nothing to return; loop to read next frame
403                                                continue;
404                                            }
405
406                                            // If payload fits entirely into buf, copy and return.
407                                            return if payload.len() <= buf.remaining() {
408                                                buf.put_slice(&payload);
409                                                Poll::Ready(Ok(()))
410                                            } else {
411                                                // Copy a part and stash remainder
412                                                let take = buf.remaining();
413                                                buf.put_slice(&payload[..take]);
414                                                self.read_buf.extend_from_slice(&payload[take..]);
415                                                Poll::Ready(Ok(()))
416                                            };
417                                        }
418
419                                        OpCode::Close => {
420                                            // Mark EOF and return 0 bytes read (Ok(()))
421                                            self.closed = true;
422                                            return Poll::Ready(Ok(()));
423                                        }
424                                        _ => {
425                                            // Ignore control frames and loop to read next frame
426                                            continue;
427                                        }
428                                    }
429                                }
430                                Err(e) => {
431                                    // restore websocket if possible? We don't have it on error.
432                                    // Map error to io::Error
433                                    return Poll::Ready(Err(make_io_err(e)));
434                                }
435                            }
436                        }
437                    }
438                }
439            }
440        }
441    }
442}
443
444impl<S> AsyncWrite for WebSocketStream<S>
445where
446    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
447{
448    fn poll_write(
449        mut self: Pin<&mut Self>,
450        cx: &mut Context<'_>,
451        buf: &[u8],
452    ) -> Poll<io::Result<usize>> {
453        // If there's already a write-in progress, poll it.
454        loop {
455            match &mut self.write_state {
456                WriteState::Idle => {
457                    // Start a new write: take websocket and create future that writes
458                    let websocket = match self.websocket.take() {
459                        Some(websocket) => websocket,
460                        None => {
461                            return Poll::Ready(Err(io::Error::other("Websocket not available")));
462                        }
463                    };
464
465                    // Copy buffer into owned Vec so the future can own it
466                    let payload = BytesMut::from(buf);
467                    let len = payload.len();
468                    let future = write(websocket, payload, self.payload_type);
469                    self.pending_write_len = Some(len);
470                    self.write_state = WriteState::Writing(future);
471                }
472                WriteState::Writing(fut) => {
473                    // poll the write future
474                    let mut future_pin = unsafe { Pin::new_unchecked(fut) };
475                    match future_pin.as_mut().poll(cx) {
476                        Poll::Pending => return Poll::Pending,
477
478                        Poll::Ready(res) => {
479                            // finish write: put websocket back
480                            self.write_state = WriteState::Idle;
481                            match res {
482                                Ok((websocket, ())) => {
483                                    self.websocket = Some(websocket);
484                                    let n = self.pending_write_len.take().unwrap_or(0);
485                                    return Poll::Ready(Ok(n));
486                                }
487                                Err(e) => return Poll::Ready(Err(make_io_err(e))),
488                            }
489                        }
490                    }
491                }
492            }
493        }
494    }
495
496    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
497        // If a write is in progress, poll it first.
498        match &mut self.write_state {
499            WriteState::Writing(_) => {
500                // let regular poll_write flow handle it; return Pending so caller
501                // should call poll_flush again later. Alternatively, we could
502                // poll it here explicitly, but reusing poll_write semantics is fine.
503                return Poll::Pending;
504            }
505            WriteState::Idle => {
506                // Start a new flush future by taking the websocket
507                let websocket = match self.websocket.take() {
508                    Some(websocket) => websocket,
509                    None => return Poll::Ready(Ok(())),
510                };
511                // empty payload for close
512                let future = flush(websocket);
513                self.write_state = WriteState::Writing(future);
514
515                // fallthrough to poll the just-created future
516            }
517        }
518
519        // Now poll the write future created above.
520        match &mut self.write_state {
521            WriteState::Writing(fut) => {
522                let mut fut_pin = unsafe { Pin::new_unchecked(fut) };
523                match fut_pin.as_mut().poll(cx) {
524                    Poll::Pending => Poll::Pending,
525                    Poll::Ready(res) => {
526                        self.write_state = WriteState::Idle;
527                        match res {
528                            Ok((websocket, ())) => {
529                                self.websocket = Some(websocket);
530                                Poll::Ready(Ok(()))
531                            }
532                            Err(e) => Poll::Ready(Err(make_io_err(e))),
533                        }
534                    }
535                }
536            }
537            _ => Poll::Ready(Ok(())),
538        }
539    }
540
541    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
542        // Implement shutdown by sending a Close frame synchronously via the
543        // same state-machine approach: start a write future that sends close.
544        // If a write is already in progress, wait for it to complete first.
545
546        // If a write is in progress, poll it first.
547        match &mut self.write_state {
548            WriteState::Writing(_) => {
549                // let regular poll_write flow handle it; return Pending so caller
550                // should call poll_shutdown again later. Alternatively, we could
551                // poll it here explicitly, but reusing poll_write semantics is fine.
552                return Poll::Pending;
553            }
554            WriteState::Idle => {
555                // start a close write
556                let websocket = match self.websocket.take() {
557                    Some(websocket) => websocket,
558                    None => return Poll::Ready(Ok(())),
559                };
560                // empty payload for close
561                let future = close(websocket);
562                self.write_state = WriteState::Writing(future);
563
564                // fallthrough to poll the just-created future
565            }
566        }
567
568        // Now poll the write future created above.
569        match &mut self.write_state {
570            WriteState::Writing(fut) => {
571                let mut fut_pin = unsafe { Pin::new_unchecked(fut) };
572                match fut_pin.as_mut().poll(cx) {
573                    Poll::Pending => Poll::Pending,
574                    Poll::Ready(res) => {
575                        self.write_state = WriteState::Idle;
576                        match res {
577                            Ok((websocket, ())) => {
578                                self.websocket = Some(websocket);
579                                Poll::Ready(Ok(()))
580                            }
581                            Err(e) => Poll::Ready(Err(make_io_err(e))),
582                        }
583                    }
584                }
585            }
586            _ => Poll::Ready(Ok(())),
587        }
588    }
589}
590
591impl<S> Debug for WebSocketStream<S> {
592    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593        // Helper to stringify read_state/write_state variants without requiring Debug on futures.
594        fn read_state_name<T>(s: &ReadState<T>) -> &'static str {
595            match s {
596                ReadState::Idle => "Idle",
597                ReadState::Reading(_) => "Reading",
598            }
599        }
600
601        fn write_state_name<T>(s: &WriteState<T>) -> &'static str {
602            match s {
603                WriteState::Idle => "Idle",
604                WriteState::Writing(_) => "Writing",
605            }
606        }
607
608        f.debug_struct("WebSocketStream")
609            .field("read_buf_len", &self.read_buf.len())
610            .field("read_state", &read_state_name(&self.read_state))
611            .field("write_state", &write_state_name(&self.write_state))
612            .field("pending_write_len", &self.pending_write_len)
613            .field("closed", &self.closed)
614            .finish()
615    }
616}