blitz_ws/protocol/frame/core.rs

1//! Utilities to work with raw WebSocket frames.
2
3use std::io::{self, Cursor, Read, Write};
4
5use bytes::{Buf, BytesMut};
6
7use crate::{
8    error::{CapacityError, Error, ProtocolError, Result},
9    protocol::frame::{
10        frame::{Frame, FrameHeader},
11        mask::apply_mask,
12    },
13};
14
15const READ_BUFFER_LENGTH: usize = 128 * 1024;
16
17/// Read buffer size used for `FrameSocket`.
18#[derive(Debug)]
19pub struct FrameSocket<T> {
20    /// The underlying network stream.
21    stream: T,
22    /// Codec for reading/writing frames.
23    codec: FrameCodec,
24}
25
26impl<T: Read + Write> FrameSocket<T> {
27    /// Create a new frame socket.
28    pub fn new(stream: T) -> Self {
29        FrameSocket { stream, codec: FrameCodec::new(READ_BUFFER_LENGTH) }
30    }
31
32    /// Create a new frame socket from partially read data.
33    pub fn from_partially_read(stream: T, part: Vec<u8>) -> Self {
34        FrameSocket { stream, codec: FrameCodec::from_partially_read(part, READ_BUFFER_LENGTH) }
35    }
36
37    /// Extract a stream from the socket.
38    pub fn into_inner(self) -> (T, BytesMut) {
39        (self.stream, self.codec.in_buffer)
40    }
41
42    /// Returns a shared reference to the inner stream.
43    pub fn get_ref(&self) -> &T {
44        &self.stream
45    }
46
47    /// Returns a mutable reference to the inner stream.
48    pub fn get_mut(&mut self) -> &mut T {
49        &mut self.stream
50    }
51
52    /// Read a frame from stream.
53    pub fn read(&mut self, max: Option<usize>) -> Result<Option<Frame>> {
54        self.codec.read(&mut self.stream, max, false, true)
55    }
56
57    /// Writes and immediately flushes a frame.
58    /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
59    pub fn send(&mut self, frame: Frame) -> Result<()> {
60        self.write(frame)?;
61        self.flush()
62    }
63
64    /// Write a frame to stream.
65    ///
66    /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
67    ///
68    /// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`]
69    /// is returned.
70    /// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards.
71    pub fn write(&mut self, frame: Frame) -> Result<()> {
72        self.codec.write(&mut self.stream, frame)
73    }
74
75    /// Flush writes.
76    pub fn flush(&mut self) -> Result<()> {
77        self.codec.write_out(&mut self.stream)?;
78        Ok(self.stream.flush()?)
79    }
80}
81
82/// A codec for WebSocket frames.
83#[derive(Debug)]
84pub(crate) struct FrameCodec {
85    /// Buffer to read data from the stream.
86    in_buffer: BytesMut,
87    in_buffer_max_read: usize,
88    /// Buffer to send packets to the network.
89    out_buffer: Vec<u8>,
90    /// Capacity limit for `out_buffer`.
91    max_out_buffer_len: usize,
92    /// Buffer target length to reach before writing to the stream
93    /// on calls to `buffer_frame`.
94    ///
95    /// Setting this to non-zero will buffer small writes from hitting
96    /// the stream.
97    out_buffer_write_len: usize,
98    /// Header and remaining size of the incoming packet being processed.
99    header: Option<(FrameHeader, u64)>,
100}
101
102impl FrameCodec {
103    /// Create a new frame codec.
104    pub(crate) fn new(len: usize) -> Self {
105        Self {
106            in_buffer: BytesMut::with_capacity(len),
107            in_buffer_max_read: len.max(FrameHeader::MAX_HEADER_SIZE),
108            out_buffer: <_>::default(),
109            max_out_buffer_len: usize::MAX,
110            out_buffer_write_len: 0,
111            header: None,
112        }
113    }
114
115    /// Create a new frame codec from partially read data.
116    pub(crate) fn from_partially_read(part: Vec<u8>, min_in_buffer_len: usize) -> Self {
117        let mut buf = BytesMut::from_iter(part);
118        buf.reserve(min_in_buffer_len.saturating_sub(buf.len()));
119
120        Self {
121            in_buffer: buf,
122            in_buffer_max_read: min_in_buffer_len.max(FrameHeader::MAX_HEADER_SIZE),
123            out_buffer: <_>::default(),
124            max_out_buffer_len: usize::MAX,
125            out_buffer_write_len: 0,
126            header: None,
127        }
128    }
129
130    /// Sets a maximum size for the out buffer.
131    pub(crate) fn max_out_buffer_len(&mut self, size: usize) {
132        self.max_out_buffer_len = size
133    }
134
135    /// Sets [`Self::buffer_frame`] buffer target length to reach before
136    /// writing to the stream.
137    pub(crate) fn out_buffer_write_len(&mut self, size: usize) {
138        self.out_buffer_write_len = size
139    }
140
141    /// Read a frame from the provided stream.
142    pub(crate) fn read<S: Read>(
143        &mut self,
144        stream: &mut S,
145        max: Option<usize>,
146        unmask: bool,
147        accept_unmasked: bool,
148    ) -> Result<Option<Frame>> {
149        let max = max.unwrap_or(usize::MAX);
150
151        let mut payload = loop {
152            if self.header.is_none() {
153                let mut cursor = Cursor::new(&mut self.in_buffer);
154                self.header = FrameHeader::parse(&mut cursor)?;
155                let n = cursor.position();
156                Buf::advance(&mut self.in_buffer, n as _);
157
158                if let Some((_, len)) = &self.header {
159                    let len = *len as usize;
160
161                    if len > max {
162                        return Err(Error::Capacity(CapacityError::MessageTooLarge {
163                            size: len,
164                            max,
165                        }));
166                    }
167
168                    self.in_buffer.reserve(len);
169                } else {
170                    self.in_buffer.reserve(FrameHeader::MAX_HEADER_SIZE);
171                }
172            }
173
174            if let Some((_, len)) = &self.header {
175                let len = *len as usize;
176                if len <= self.in_buffer.len() {
177                    break self.in_buffer.split_to(len);
178                }
179            }
180
181            if self.read_in(stream)? == 0 {
182                return Ok(None);
183            }
184        };
185
186        let (mut header, length) = self.header.take().expect("Bug: no frame header");
187        debug_assert_eq!(payload.len() as u64, length);
188
189        if unmask {
190            if let Some(mask) = header.mask.take() {
191                apply_mask(&mut payload, mask);
192            } else if !accept_unmasked {
193                return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient));
194            }
195        }
196
197        let frame = Frame::new(header, payload.freeze());
198        Ok(Some(frame))
199    }
200
201    /// Read into available `in_buffer` capacity.
202    fn read_in<S: Read>(&mut self, stream: &mut S) -> io::Result<usize> {
203        let len = self.in_buffer.len();
204        debug_assert!(self.in_buffer.capacity() > len);
205
206        self.in_buffer.resize(self.in_buffer.capacity().min(len + self.in_buffer_max_read), 0);
207
208        let size = stream.read(&mut self.in_buffer[len..]);
209        self.in_buffer.truncate(len + size.as_ref().copied().unwrap_or(0));
210
211        size
212    }
213
214    /// Writes a frame into the `out_buffer`.
215    /// If the out buffer size is over the `out_buffer_write_len` will also write
216    /// the out buffer into the provided `stream`.
217    ///
218    /// To ensure buffered frames are written call [`Self::write_out_buffer`].
219    ///
220    /// May write to the stream, will **not** flush.
221    pub(crate) fn write<S: Write>(&mut self, stream: &mut S, frame: Frame) -> Result<()> {
222        if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
223            return Err(Error::WriteBufferFull);
224        }
225
226        self.out_buffer.reserve(frame.len());
227        frame.into_buf(&mut self.out_buffer).expect("Bug: can't write to vector");
228
229        if self.out_buffer.len() > self.out_buffer_write_len {
230            self.write_out(stream)
231        } else {
232            Ok(())
233        }
234    }
235
236    /// Writes the out_buffer to the provided stream.
237    ///
238    /// Does **not** flush.
239    pub(crate) fn write_out<S: Write>(&mut self, stream: &mut S) -> Result<()> {
240        while !self.out_buffer.is_empty() {
241            let len = stream.write(&self.out_buffer)?;
242
243            if len == 0 {
244                return Err(io::Error::new(
245                    io::ErrorKind::ConnectionReset,
246                    "Connection reset while sending",
247                )
248                .into());
249            }
250
251            self.out_buffer.drain(0..len);
252        }
253
254        Ok(())
255    }
256}