fastwebsockets_stream/stream.rs
1use bytes::Bytes;
2use bytes::BytesMut;
3use fastwebsockets::{Frame, OpCode, Payload, WebSocket, WebSocketError};
4use futures::FutureExt;
5use futures::future::BoxFuture;
6use std::fmt::Debug;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
/// Output type of the boxed futures that temporarily own the websocket.
///
/// On success the future hands back the `WebSocket<S>` it was given together
/// with the operation's result value `T`. On failure only the
/// `WebSocketError` is returned, so the websocket is not handed back.
type FutureResult<S, T> = Result<(WebSocket<S>, T), WebSocketError>;
17
/// Owned copy of a received frame.
///
/// The `read` helper converts the payload of the frame returned by
/// `WebSocket::read_frame()` into an owned `Bytes` and pairs it with the
/// frame's opcode, so the result can be returned from the future together
/// with the websocket itself.
struct PayloadFrame {
    /// Opcode of the received frame (Text/Binary/Close/etc).
    opcode: OpCode,
    /// Owned payload bytes of the received frame.
    payload: Bytes,
}
29
/// Read-side state of a `WebSocketStream`.
///
/// The stream is either idle or driving a boxed future that has taken
/// ownership of the underlying `WebSocket` in order to read one frame. When
/// that future succeeds it yields the websocket back together with the
/// `PayloadFrame` it read.
enum ReadState<S> {
    /// No read in progress.
    Idle,
    /// A read is in progress: the boxed future owns the websocket and
    /// resolves to the websocket plus the frame that was read.
    Reading(BoxFuture<'static, FutureResult<S, PayloadFrame>>),
}
43
/// Write-side state of a `WebSocketStream`.
///
/// Mirrors `ReadState`: the stream is either idle or driving a boxed future
/// that owns the websocket until it completes. The same variant is used for
/// the futures created by the `write`, `flush` and `close` helpers.
enum WriteState<S> {
    /// No write-side operation in progress.
    Idle,
    /// A write-side operation is in progress: the boxed future owns the
    /// websocket and resolves to the websocket once the operation is done.
    Writing(BoxFuture<'static, FutureResult<S, ()>>),
}
55
/// Kind of application data carried by a `WebSocketStream`.
///
/// Selects whether the stream sends Text or Binary data frames. The same
/// value is used to validate the opcode of data frames received from the
/// peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PayloadType {
    /// Binary frames.
    Binary,
    /// UTF-8 Text frames.
    Text,
}
68
69impl From<PayloadType> for OpCode {
70 fn from(value: PayloadType) -> Self {
71 match value {
72 PayloadType::Binary => OpCode::Binary,
73 PayloadType::Text => OpCode::Text,
74 }
75 }
76}
77
78/// Map a `WebSocketError` into an `io::Error` for compatibility with the
79/// `AsyncRead`/`AsyncWrite` trait surfaces.
80fn make_io_err(e: WebSocketError) -> io::Error {
81 io::Error::other(format!("Websocket error: {}", e))
82}
83
84/// Helper: create a boxed future that owns the websocket and reads a frame.
85///
86/// The returned future will call `websocket.read_frame().await`, copy the
87/// payload into an owned `Bytes`, and return `(websocket, PayloadFrame)` on
88/// success or `WebSocketError` on failure.
89///
90/// This helper is private because it requires taking ownership of the
91/// `WebSocket` (which is stored as `Option` inside `WebSocketStream`) and
92/// boxing the resulting future so the `WebSocketStream` state machine can store
93/// it.
94fn read<S>(mut websocket: WebSocket<S>) -> BoxFuture<'static, FutureResult<S, PayloadFrame>>
95where
96 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
97{
98 async move {
99 // read_frame() returns Frame<'_> which borrows the websocket's buffers;
100 // we immediately copy the payload into an owned Bytes so the PayloadFrame
101 // can be returned with the websocket.
102 match websocket.read_frame().await {
103 Ok(frame) => {
104 let payload = match frame.payload {
105 Payload::BorrowedMut(buf) => Bytes::from(buf.to_vec()),
106 Payload::Borrowed(buf) => Bytes::from(buf.to_vec()),
107 Payload::Owned(vec) => Bytes::from(vec),
108 Payload::Bytes(bytes) => bytes.freeze(),
109 };
110
111 let owned = PayloadFrame {
112 opcode: frame.opcode,
113 payload,
114 };
115 Ok((websocket, owned))
116 }
117 Err(e) => Err(e),
118 }
119 }
120 .boxed()
121}
122
123/// Helper: create a boxed future that owns the websocket and writes the provided payload.
124///
125/// This helper constructs a single-frame message with the chosen `payload_type`
126/// (Text or Binary) and writes it with `websocket.write_frame(...)`. The future
127/// returns the websocket on success so ownership can be restored to the stream.
128fn write<S>(
129 mut websocket: WebSocket<S>,
130 payload: BytesMut,
131 payload_type: PayloadType,
132) -> BoxFuture<'static, FutureResult<S, ()>>
133where
134 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
135{
136 async move {
137 let frame = Frame::new(true, payload_type.into(), None, Payload::Bytes(payload));
138 match websocket.write_frame(frame).await {
139 Ok(()) => Ok((websocket, ())),
140 Err(e) => Err(e),
141 }
142 }
143 .boxed()
144}
145
146/// Helper: create a boxed future that owns the websocket and flushes it.
147///
148/// This issues a flush on the underlying `WebSocket` (which may flush any
149/// internal write buffers) and returns the websocket afterwards.
150fn flush<S>(mut websocket: WebSocket<S>) -> BoxFuture<'static, FutureResult<S, ()>>
151where
152 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
153{
154 async move {
155 match websocket.flush().await {
156 Ok(()) => Ok((websocket, ())),
157 Err(e) => Err(e),
158 }
159 }
160 .boxed()
161}
162
163/// Helper: create a boxed future that owns the websocket and sends a Close frame.
164///
165/// This writes a close frame and returns the websocket. Used by `poll_shutdown`.
166fn close<S>(mut websocket: WebSocket<S>) -> BoxFuture<'static, FutureResult<S, ()>>
167where
168 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
169{
170 async move {
171 let frame = Frame::close_raw(Vec::new().into());
172 match websocket.write_frame(frame).await {
173 Ok(()) => Ok((websocket, ())),
174 Err(e) => Err(e),
175 }
176 }
177 .boxed()
178}
179
/// An `AsyncRead` / `AsyncWrite` adapter over a `fastwebsockets::WebSocket`.
///
/// `WebSocketStream<S>` wraps a `WebSocket<S>` and exposes a byte-stream view
/// (implementing `tokio::io::AsyncRead` and `tokio::io::AsyncWrite`) so that
/// websocket application payloads can be used with existing I/O and codec
/// infrastructure such as `tokio_util::codec::Framed`.
///
/// ## Behavior
///
/// * Incoming WebSocket data frames (Text or Binary depending on the stream's
///   `PayloadType`) are presented as a continuous byte stream. Each data
///   frame's payload is returned in-order; if a read buffer provided by the
///   caller is smaller than a frame payload, the remainder is buffered
///   internally and served on subsequent reads.
/// * Control frames (Ping/Pong) are handled by the underlying `WebSocket`
///   (auto-pong) or ignored by this adapter. A `Close` frame marks EOF and
///   subsequent reads return `Ok(())` with zero bytes (standard EOF
///   semantics).
/// * Writes produce single complete WebSocket data frames of the configured
///   `PayloadType`. Each `poll_write` call sends one WebSocket data frame with
///   the provided bytes as payload. The number of bytes reported as written is
///   the length of `buf` supplied to `poll_write`.
///
/// ## Notes on ownership
///
/// The adapter temporarily gives up ownership of the inner `WebSocket` when it
/// needs to perform an asynchronous operation: the websocket is moved into a
/// boxed future, which is stored in `ReadState` / `WriteState`, polled from the
/// `poll_*` methods, and hands the websocket back when it completes
/// successfully. Nothing is spawned onto a runtime.
///
/// Because there is only one websocket, only one such future can own it at a
/// time. Starting a read while a write-side operation holds the websocket (or
/// starting a write while a read holds it) fails with an `io::Error` whose
/// message is "Websocket not available". If an operation fails, the websocket
/// is not handed back and later operations observe the same condition.
///
/// ## Example
///
/// ```rust
/// use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// use tokio::net::TcpStream;
/// use fastwebsockets::WebSocket;
/// use fastwebsockets_stream::{WebSocketStream, PayloadType};
///
/// // Use the adapter directly through the tokio I/O extension traits:
/// async fn example<S>(_ws: WebSocket<S>)
/// where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
///     // This example is illustrative: constructing a real `WebSocket` requires
///     // an underlying transport (e.g. a `TcpStream`) and the fastwebsockets
///     // connection/handshake. Assume `ws` is a valid WebSocket<TcpStream>.
///
///     let ws: WebSocket<S> = unimplemented!();
///     let mut ws_stream = WebSocketStream::new(ws, PayloadType::Binary);
///
///     // Write bytes -> sends a Binary frame
///     let _n = ws_stream.write(b"hello").await;
///
///     // Read bytes
///     let mut buf = vec![0_u8; 1024];
///     let _ = ws_stream.read(&mut buf).await;
///
///     // Shutdown (sends Close)
///     let _ = ws_stream.shutdown().await;
/// }
/// ```
///
/// Another common usage is to use `tokio_util::codec::Framed` to apply a codec
/// on top of `WebSocketStream` (for example a length-delimited or line-based
/// codec). Example:
///
/// ```rust
/// use tokio_util::codec::{Framed, LinesCodec};
/// use fastwebsockets::WebSocket;
/// use fastwebsockets_stream::{WebSocketStream, PayloadType};
///
/// // Wrap the websocket and apply a line-based codec:
/// async fn example<S>(_ws: WebSocket<S>)
/// where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
///     let ws: WebSocket<S> = unimplemented!();
///     let stream = WebSocketStream::new(ws, PayloadType::Text);
///     let mut framed = Framed::new(stream, LinesCodec::new());
///
///     // `framed` is now a `Stream` + `Sink` of `String` lines; drive it with
///     // e.g. `StreamExt::next()` and `SinkExt::send()`.
/// }
/// ```
pub struct WebSocketStream<S> {
    /// The inner websocket. Stored as an `Option` so it can be moved into an
    /// owned future while an operation is in flight; `None` means some future
    /// currently owns it (or it was lost to a failed operation).
    websocket: Option<WebSocket<S>>,

    /// Leftover bytes of the most recent incoming data frame that did not fit
    /// into the caller-provided read buffer. Served first by `poll_read`.
    read_buf: BytesMut,

    /// State machine for an in-progress read future that owns the websocket.
    read_state: ReadState<S>,

    /// State machine for an in-progress write-side future (write, flush or
    /// close) that owns the websocket.
    write_state: WriteState<S>,

    /// If `Some(n)` then a data write has been started and `n` bytes are to be
    /// reported as written once its future completes. The length is stored
    /// separately because the write future's output carries only the
    /// websocket.
    pending_write_len: Option<usize>,

    /// Expected and emitted payload type (Text or Binary). Received data
    /// frames with the other data opcode are treated as errors.
    payload_type: PayloadType,

    /// Set to `true` after a Close frame has been observed by `poll_read`.
    /// When `closed` is true, subsequent reads return EOF.
    closed: bool,
}
289
290impl<S> WebSocketStream<S>
291where
292 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
293{
294 /// Create a new `WebSocketStream` wrapping the provided `WebSocket`.
295 ///
296 /// This will enable automatic Pong replies and automatic Close handling on
297 /// the wrapped `WebSocket` and initialize internal buffers and state.
298 ///
299 /// `payload_type` selects whether this stream should read/write Text or
300 /// Binary data. If the peer sends data frames with an opcode that does not
301 /// match `payload_type`, reads will return an error.
302 pub fn new(mut websocket: WebSocket<S>, payload_type: PayloadType) -> Self {
303 // Set auto pong and close
304 websocket.set_auto_pong(true);
305 websocket.set_auto_close(true);
306
307 Self {
308 websocket: Some(websocket),
309 read_buf: BytesMut::with_capacity(8 * 1024),
310 read_state: ReadState::Idle,
311 write_state: WriteState::Idle,
312 pending_write_len: None,
313 payload_type,
314 closed: false,
315 }
316 }
317
318 /// Consume the adapter and attempt to return the inner `WebSocket`.
319 ///
320 /// This returns `Some(WebSocket<S>)` if the websocket currently resides in
321 /// the adapter. If there is an outstanding future that currently owns the
322 /// websocket (i.e. a read or write in progress) this method will return
323 /// `None` because the adapter cannot recover the websocket until that
324 /// future completes.
325 pub fn into_inner(mut self) -> Option<WebSocket<S>> {
326 // If there is an outstanding future that currently owns the websocket,
327 // we cannot recover it here. We only return the inner websocket if it
328 // currently resides in `self.ws`.
329 self.websocket.take()
330 }
331
332 /// Returns `true` if we've observed a Close frame from the peer and the
333 /// stream reached EOF.
334 pub fn is_closed(&self) -> bool {
335 self.closed
336 }
337}
338
339impl<S> AsyncRead for WebSocketStream<S>
340where
341 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
342{
343 fn poll_read(
344 mut self: Pin<&mut Self>,
345 cx: &mut Context<'_>,
346 buf: &mut ReadBuf<'_>,
347 ) -> Poll<io::Result<()>> {
348 // If there are buffered bytes from previous frame, satisfy the read.
349 if !self.read_buf.is_empty() {
350 let to_copy = std::cmp::min(self.read_buf.len(), buf.remaining());
351 buf.put_slice(&self.read_buf.split_to(to_copy));
352 return Poll::Ready(Ok(()));
353 }
354
355 // If we've previously observed Close/EOF, report EOF by returning Ok(())
356 if self.closed {
357 return Poll::Ready(Ok(()));
358 }
359
360 loop {
361 // Match current read future state
362 match &mut self.read_state {
363 ReadState::Idle => {
364 // Start a new read future by taking the websocket
365 let websocket = match self.websocket.take() {
366 Some(websocket) => websocket,
367 None => {
368 return Poll::Ready(Err(io::Error::other("Websocket not available")));
369 }
370 };
371 let future = read(websocket);
372 self.read_state = ReadState::Reading(future);
373 }
374 ReadState::Reading(fut) => {
375 // Poll the future. If Pending, return Pending. If Ready,
376 // reinstate websocket and handle frame.
377 let mut future_pin = unsafe { Pin::new_unchecked(fut) };
378 match future_pin.as_mut().poll(cx) {
379 Poll::Pending => return Poll::Pending,
380 Poll::Ready(res) => {
381 // Transition back to Idle
382 self.read_state = ReadState::Idle;
383 match res {
384 Ok((websocket, frame)) => {
385 // Put websocket back
386 self.websocket = Some(websocket);
387
388 match frame.opcode {
389 OpCode::Binary | OpCode::Text => {
390 // If frame payload type isn't match the desired type,
391 // return error
392 if frame.opcode != self.payload_type.into() {
393 return Poll::Ready(Err(io::Error::other(
394 "The received data type is different \
395 from the stream data type",
396 )));
397 }
398
399 // Check frame payload
400 let payload = frame.payload;
401 if payload.is_empty() {
402 // Nothing to return; loop to read next frame
403 continue;
404 }
405
406 // If payload fits entirely into buf, copy and return.
407 return if payload.len() <= buf.remaining() {
408 buf.put_slice(&payload);
409 Poll::Ready(Ok(()))
410 } else {
411 // Copy a part and stash remainder
412 let take = buf.remaining();
413 buf.put_slice(&payload[..take]);
414 self.read_buf.extend_from_slice(&payload[take..]);
415 Poll::Ready(Ok(()))
416 };
417 }
418
419 OpCode::Close => {
420 // Mark EOF and return 0 bytes read (Ok(()))
421 self.closed = true;
422 return Poll::Ready(Ok(()));
423 }
424 _ => {
425 // Ignore control frames and loop to read next frame
426 continue;
427 }
428 }
429 }
430 Err(e) => {
431 // restore websocket if possible? We don't have it on error.
432 // Map error to io::Error
433 return Poll::Ready(Err(make_io_err(e)));
434 }
435 }
436 }
437 }
438 }
439 }
440 }
441 }
442}
443
444impl<S> AsyncWrite for WebSocketStream<S>
445where
446 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
447{
448 fn poll_write(
449 mut self: Pin<&mut Self>,
450 cx: &mut Context<'_>,
451 buf: &[u8],
452 ) -> Poll<io::Result<usize>> {
453 // If there's already a write-in progress, poll it.
454 loop {
455 match &mut self.write_state {
456 WriteState::Idle => {
457 // Start a new write: take websocket and create future that writes
458 let websocket = match self.websocket.take() {
459 Some(websocket) => websocket,
460 None => {
461 return Poll::Ready(Err(io::Error::other("Websocket not available")));
462 }
463 };
464
465 // Copy buffer into owned Vec so the future can own it
466 let payload = BytesMut::from(buf);
467 let len = payload.len();
468 let future = write(websocket, payload, self.payload_type);
469 self.pending_write_len = Some(len);
470 self.write_state = WriteState::Writing(future);
471 }
472 WriteState::Writing(fut) => {
473 // poll the write future
474 let mut future_pin = unsafe { Pin::new_unchecked(fut) };
475 match future_pin.as_mut().poll(cx) {
476 Poll::Pending => return Poll::Pending,
477
478 Poll::Ready(res) => {
479 // finish write: put websocket back
480 self.write_state = WriteState::Idle;
481 match res {
482 Ok((websocket, ())) => {
483 self.websocket = Some(websocket);
484 let n = self.pending_write_len.take().unwrap_or(0);
485 return Poll::Ready(Ok(n));
486 }
487 Err(e) => return Poll::Ready(Err(make_io_err(e))),
488 }
489 }
490 }
491 }
492 }
493 }
494 }
495
496 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
497 // If a write is in progress, poll it first.
498 match &mut self.write_state {
499 WriteState::Writing(_) => {
500 // let regular poll_write flow handle it; return Pending so caller
501 // should call poll_flush again later. Alternatively, we could
502 // poll it here explicitly, but reusing poll_write semantics is fine.
503 return Poll::Pending;
504 }
505 WriteState::Idle => {
506 // Start a new flush future by taking the websocket
507 let websocket = match self.websocket.take() {
508 Some(websocket) => websocket,
509 None => return Poll::Ready(Ok(())),
510 };
511 // empty payload for close
512 let future = flush(websocket);
513 self.write_state = WriteState::Writing(future);
514
515 // fallthrough to poll the just-created future
516 }
517 }
518
519 // Now poll the write future created above.
520 match &mut self.write_state {
521 WriteState::Writing(fut) => {
522 let mut fut_pin = unsafe { Pin::new_unchecked(fut) };
523 match fut_pin.as_mut().poll(cx) {
524 Poll::Pending => Poll::Pending,
525 Poll::Ready(res) => {
526 self.write_state = WriteState::Idle;
527 match res {
528 Ok((websocket, ())) => {
529 self.websocket = Some(websocket);
530 Poll::Ready(Ok(()))
531 }
532 Err(e) => Poll::Ready(Err(make_io_err(e))),
533 }
534 }
535 }
536 }
537 _ => Poll::Ready(Ok(())),
538 }
539 }
540
541 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
542 // Implement shutdown by sending a Close frame synchronously via the
543 // same state-machine approach: start a write future that sends close.
544 // If a write is already in progress, wait for it to complete first.
545
546 // If a write is in progress, poll it first.
547 match &mut self.write_state {
548 WriteState::Writing(_) => {
549 // let regular poll_write flow handle it; return Pending so caller
550 // should call poll_shutdown again later. Alternatively, we could
551 // poll it here explicitly, but reusing poll_write semantics is fine.
552 return Poll::Pending;
553 }
554 WriteState::Idle => {
555 // start a close write
556 let websocket = match self.websocket.take() {
557 Some(websocket) => websocket,
558 None => return Poll::Ready(Ok(())),
559 };
560 // empty payload for close
561 let future = close(websocket);
562 self.write_state = WriteState::Writing(future);
563
564 // fallthrough to poll the just-created future
565 }
566 }
567
568 // Now poll the write future created above.
569 match &mut self.write_state {
570 WriteState::Writing(fut) => {
571 let mut fut_pin = unsafe { Pin::new_unchecked(fut) };
572 match fut_pin.as_mut().poll(cx) {
573 Poll::Pending => Poll::Pending,
574 Poll::Ready(res) => {
575 self.write_state = WriteState::Idle;
576 match res {
577 Ok((websocket, ())) => {
578 self.websocket = Some(websocket);
579 Poll::Ready(Ok(()))
580 }
581 Err(e) => Poll::Ready(Err(make_io_err(e))),
582 }
583 }
584 }
585 }
586 _ => Poll::Ready(Ok(())),
587 }
588 }
589}
590
591impl<S> Debug for WebSocketStream<S> {
592 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593 // Helper to stringify read_state/write_state variants without requiring Debug on futures.
594 fn read_state_name<T>(s: &ReadState<T>) -> &'static str {
595 match s {
596 ReadState::Idle => "Idle",
597 ReadState::Reading(_) => "Reading",
598 }
599 }
600
601 fn write_state_name<T>(s: &WriteState<T>) -> &'static str {
602 match s {
603 WriteState::Idle => "Idle",
604 WriteState::Writing(_) => "Writing",
605 }
606 }
607
608 f.debug_struct("WebSocketStream")
609 .field("read_buf_len", &self.read_buf.len())
610 .field("read_state", &read_state_name(&self.read_state))
611 .field("write_state", &write_state_name(&self.write_state))
612 .field("pending_write_len", &self.pending_write_len)
613 .field("closed", &self.closed)
614 .finish()
615 }
616}