blitz_ws/protocol/frame/
core.rs1use std::io::{self, Cursor, Read, Write};
4
5use bytes::{Buf, BytesMut};
6
7use crate::{
8 error::{CapacityError, Error, ProtocolError, Result},
9 protocol::frame::{
10 frame::{Frame, FrameHeader},
11 mask::apply_mask,
12 },
13};
14
/// Size, in bytes, handed to the codec by both `FrameSocket` constructors
/// (128 KiB). It is used as the input buffer's starting capacity and as the
/// basis for the per-read size limit.
const READ_BUFFER_LENGTH: usize = 128 * 1024;
16
17#[derive(Debug)]
19pub struct FrameSocket<T> {
20 stream: T,
22 codec: FrameCodec,
24}
25
26impl<T: Read + Write> FrameSocket<T> {
27 pub fn new(stream: T) -> Self {
29 FrameSocket { stream, codec: FrameCodec::new(READ_BUFFER_LENGTH) }
30 }
31
32 pub fn from_partially_read(stream: T, part: Vec<u8>) -> Self {
34 FrameSocket { stream, codec: FrameCodec::from_partially_read(part, READ_BUFFER_LENGTH) }
35 }
36
37 pub fn into_inner(self) -> (T, BytesMut) {
39 (self.stream, self.codec.in_buffer)
40 }
41
42 pub fn get_ref(&self) -> &T {
44 &self.stream
45 }
46
47 pub fn get_mut(&mut self) -> &mut T {
49 &mut self.stream
50 }
51
52 pub fn read(&mut self, max: Option<usize>) -> Result<Option<Frame>> {
54 self.codec.read(&mut self.stream, max, false, true)
55 }
56
57 pub fn send(&mut self, frame: Frame) -> Result<()> {
60 self.write(frame)?;
61 self.flush()
62 }
63
64 pub fn write(&mut self, frame: Frame) -> Result<()> {
72 self.codec.write(&mut self.stream, frame)
73 }
74
75 pub fn flush(&mut self) -> Result<()> {
77 self.codec.write_out(&mut self.stream)?;
78 Ok(self.stream.flush()?)
79 }
80}
81
82#[derive(Debug)]
84pub(crate) struct FrameCodec {
85 in_buffer: BytesMut,
87 in_buffer_max_read: usize,
88 out_buffer: Vec<u8>,
90 max_out_buffer_len: usize,
92 out_buffer_write_len: usize,
98 header: Option<(FrameHeader, u64)>,
100}
101
102impl FrameCodec {
103 pub(crate) fn new(len: usize) -> Self {
105 Self {
106 in_buffer: BytesMut::with_capacity(len),
107 in_buffer_max_read: len.max(FrameHeader::MAX_HEADER_SIZE),
108 out_buffer: <_>::default(),
109 max_out_buffer_len: usize::MAX,
110 out_buffer_write_len: 0,
111 header: None,
112 }
113 }
114
115 pub(crate) fn from_partially_read(part: Vec<u8>, min_in_buffer_len: usize) -> Self {
117 let mut buf = BytesMut::from_iter(part);
118 buf.reserve(min_in_buffer_len.saturating_sub(buf.len()));
119
120 Self {
121 in_buffer: buf,
122 in_buffer_max_read: min_in_buffer_len.max(FrameHeader::MAX_HEADER_SIZE),
123 out_buffer: <_>::default(),
124 max_out_buffer_len: usize::MAX,
125 out_buffer_write_len: 0,
126 header: None,
127 }
128 }
129
130 pub(crate) fn max_out_buffer_len(&mut self, size: usize) {
132 self.max_out_buffer_len = size
133 }
134
135 pub(crate) fn out_buffer_write_len(&mut self, size: usize) {
138 self.out_buffer_write_len = size
139 }
140
141 pub(crate) fn read<S: Read>(
143 &mut self,
144 stream: &mut S,
145 max: Option<usize>,
146 unmask: bool,
147 accept_unmasked: bool,
148 ) -> Result<Option<Frame>> {
149 let max = max.unwrap_or(usize::MAX);
150
151 let mut payload = loop {
152 if self.header.is_none() {
153 let mut cursor = Cursor::new(&mut self.in_buffer);
154 self.header = FrameHeader::parse(&mut cursor)?;
155 let n = cursor.position();
156 Buf::advance(&mut self.in_buffer, n as _);
157
158 if let Some((_, len)) = &self.header {
159 let len = *len as usize;
160
161 if len > max {
162 return Err(Error::Capacity(CapacityError::MessageTooLarge {
163 size: len,
164 max,
165 }));
166 }
167
168 self.in_buffer.reserve(len);
169 } else {
170 self.in_buffer.reserve(FrameHeader::MAX_HEADER_SIZE);
171 }
172 }
173
174 if let Some((_, len)) = &self.header {
175 let len = *len as usize;
176 if len <= self.in_buffer.len() {
177 break self.in_buffer.split_to(len);
178 }
179 }
180
181 if self.read_in(stream)? == 0 {
182 return Ok(None);
183 }
184 };
185
186 let (mut header, length) = self.header.take().expect("Bug: no frame header");
187 debug_assert_eq!(payload.len() as u64, length);
188
189 if unmask {
190 if let Some(mask) = header.mask.take() {
191 apply_mask(&mut payload, mask);
192 } else if !accept_unmasked {
193 return Err(Error::Protocol(ProtocolError::UnmaskedFrameFromClient));
194 }
195 }
196
197 let frame = Frame::new(header, payload.freeze());
198 Ok(Some(frame))
199 }
200
201 fn read_in<S: Read>(&mut self, stream: &mut S) -> io::Result<usize> {
203 let len = self.in_buffer.len();
204 debug_assert!(self.in_buffer.capacity() > len);
205
206 self.in_buffer.resize(self.in_buffer.capacity().min(len + self.in_buffer_max_read), 0);
207
208 let size = stream.read(&mut self.in_buffer[len..]);
209 self.in_buffer.truncate(len + size.as_ref().copied().unwrap_or(0));
210
211 size
212 }
213
214 pub(crate) fn write<S: Write>(&mut self, stream: &mut S, frame: Frame) -> Result<()> {
222 if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
223 return Err(Error::WriteBufferFull);
224 }
225
226 self.out_buffer.reserve(frame.len());
227 frame.into_buf(&mut self.out_buffer).expect("Bug: can't write to vector");
228
229 if self.out_buffer.len() > self.out_buffer_write_len {
230 self.write_out(stream)
231 } else {
232 Ok(())
233 }
234 }
235
236 pub(crate) fn write_out<S: Write>(&mut self, stream: &mut S) -> Result<()> {
240 while !self.out_buffer.is_empty() {
241 let len = stream.write(&self.out_buffer)?;
242
243 if len == 0 {
244 return Err(io::Error::new(
245 io::ErrorKind::ConnectionReset,
246 "Connection reset while sending",
247 )
248 .into());
249 }
250
251 self.out_buffer.drain(0..len);
252 }
253
254 Ok(())
255 }
256}