1use std::{
4 fmt::Display,
5 io::{Cursor, ErrorKind, Read, Write},
6 mem,
7 result::Result as StdResult,
8 str::Utf8Error,
9};
10
11use bytes::{Bytes, BytesMut};
12
13use super::{
14 codec::{CloseCode, Control, Data, OpCode},
15 mask::{apply_mask, generate},
16};
17use crate::{
18 error::{Error, ProtocolError, Result},
19 protocol::frame::Utf8Bytes,
20};
21
22#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct CloseFrame {
25 pub code: CloseCode,
27 pub reason: Utf8Bytes,
29}
30
31impl Display for CloseFrame {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 write!(f, "{} ({})", self.reason, self.code)
34 }
35}
36
37#[allow(missing_copy_implementations)]
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub struct FrameHeader {
41 pub fin: bool,
43 pub rsv1: bool,
45 pub rsv2: bool,
47 pub rsv3: bool,
49 pub opcode: OpCode,
51 pub mask: Option<[u8; 4]>,
53}
54
55impl Default for FrameHeader {
56 fn default() -> Self {
57 FrameHeader {
58 fin: false,
59 rsv1: false,
60 rsv2: false,
61 rsv3: false,
62 opcode: OpCode::Control(Control::Close),
63 mask: None,
64 }
65 }
66}
67
68impl FrameHeader {
69 pub(crate) const MAX_HEADER_SIZE: usize = 14;
72
73 pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
77 let init = cursor.position();
78
79 match Self::parse_internal(cursor) {
80 i @ Ok(None) => {
81 cursor.set_position(init);
82 i
83 }
84 other => other,
85 }
86 }
87
88 #[allow(clippy::len_without_is_empty)]
90 pub fn len(&self, length: u64) -> usize {
91 2 + Length::for_len(length).additional() + (if self.mask.is_some() { 4 } else { 0 })
92 }
93
94 pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
96 let code: u8 = self.opcode.into();
97
98 let first_byte = {
99 code | if self.fin { 0x80 } else { 0 }
100 | if self.rsv1 { 0x40 } else { 0 }
101 | if self.rsv2 { 0x20 } else { 0 }
102 | if self.rsv3 { 0x10 } else { 0 }
103 };
104
105 let len = Length::for_len(length);
106
107 let second_byte = len.len_byte() | if self.mask.is_some() { 0x80 } else { 0 };
108
109 output.write_all(&[first_byte, second_byte])?;
110
111 match len {
112 Length::U8(_) => (),
113 Length::U16 => {
114 output.write_all(&(length as u16).to_be_bytes())?;
115 }
116 Length::U64 => {
117 output.write_all(&length.to_be_bytes())?;
118 }
119 }
120
121 if let Some(ref mask) = self.mask {
122 output.write_all(mask)?;
123 }
124
125 Ok(())
126 }
127
128 pub(crate) fn set_random_mask(&mut self) {
132 self.mask = Some(generate());
133 }
134
135 fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
139 let (a, b) = {
140 let mut head = [0u8; 2];
141 if cursor.read(&mut head)? != 2 {
142 return Ok(None);
143 }
144
145 (head[0], head[1])
146 };
147
148 let fin = a & 0x80 != 0;
149 let rsv1 = a & 0x40 != 0;
150 let rsv2 = a & 0x20 != 0;
151 let rsv3 = a & 0x10 != 0;
152
153 let opcode = OpCode::from(a & 0x0F);
154
155 let masked = b & 0x80 != 0;
156
157 let len = {
158 let len_byte = b & 0x7F;
159 let particular_len = Length::for_byte(len_byte).additional();
160
161 if particular_len > 0 {
162 const SIZE: usize = mem::size_of::<u64>();
163 assert!(
164 particular_len < SIZE,
165 "Length exceeded max size of unsigned 64-bit integer"
166 );
167
168 let start = SIZE - particular_len;
169 let mut buf = [0u8; SIZE];
170
171 match cursor.read_exact(&mut buf[start..]) {
172 Err(ref e) if e.kind() == ErrorKind::UnexpectedEof => return Ok(None),
173 Err(e) => return Err(e.into()),
174 Ok(()) => u64::from_be_bytes(buf),
175 }
176 } else {
177 u64::from(len_byte)
178 }
179 };
180
181 let mask = if masked {
182 let mut mask_bytes = [0u8; 4];
183 if cursor.read(&mut mask_bytes)? != 4 {
184 return Ok(None);
185 } else {
186 Some(mask_bytes)
187 }
188 } else {
189 None
190 };
191
192 match opcode {
193 OpCode::Control(Control::Reserved(_)) => {
194 return Err(Error::Protocol(ProtocolError::UnknownControlOpCode(a & 0x0F)));
195 }
196 OpCode::Data(Data::Reserved(_)) => {
197 return Err(Error::Protocol(ProtocolError::UnknownDataOpCode(a & 0x0F)));
198 }
199 _ => (),
200 };
201
202 let header = FrameHeader { fin, rsv1, rsv2, rsv3, opcode, mask };
203
204 Ok(Some((header, len)))
205 }
206}
207
208impl Frame {}
209
210#[derive(Debug, Clone, PartialEq, Eq)]
212pub struct Frame {
213 header: FrameHeader,
214 payload: Bytes,
215}
216
217impl Frame {
218 #[inline]
221 pub fn len(&self) -> usize {
222 let length = self.payload.len();
223 self.header.len(length as u64) + length
224 }
225
226 #[inline]
228 pub fn is_empty(&self) -> bool {
229 self.len() == 0
230 }
231
232 #[inline]
234 pub fn header(&self) -> &FrameHeader {
235 &self.header
236 }
237
238 #[inline]
240 pub fn header_mut(&mut self) -> &mut FrameHeader {
241 &mut self.header
242 }
243
244 #[inline]
246 pub fn payload(&self) -> &[u8] {
247 &self.payload
248 }
249
250 #[inline]
252 pub(crate) fn is_masked(&self) -> bool {
253 self.header.mask.is_some()
254 }
255
256 #[inline]
261 pub(crate) fn set_random_mask(&mut self) {
262 self.header.set_random_mask();
263 }
264
265 #[inline]
267 pub fn into_text(self) -> StdResult<Utf8Bytes, Utf8Error> {
268 self.payload.try_into()
269 }
270
271 #[inline]
273 pub fn into_payload(self) -> Bytes {
274 self.payload
275 }
276
277 #[inline]
279 pub fn to_text(&self) -> Result<&str, Utf8Error> {
280 std::str::from_utf8(&self.payload)
281 }
282
283 #[inline]
285 pub(crate) fn into_close(self) -> Result<Option<CloseFrame>> {
286 match self.payload.len() {
287 0 => Ok(None),
288 1 => Err(Error::Protocol(ProtocolError::InvalidCloseFrame)),
289 _ => {
290 let code = u16::from_be_bytes([self.payload[0], self.payload[1]]).into();
291 let reason = Utf8Bytes::try_from(self.payload.slice(2..))?;
292
293 Ok(Some(CloseFrame { code, reason }))
294 }
295 }
296 }
297
298 #[inline]
300 pub fn new_data(data: impl Into<Bytes>, opcode: OpCode, fin: bool) -> Frame {
301 debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame");
302
303 Frame { header: FrameHeader { fin, opcode, ..Default::default() }, payload: data.into() }
304 }
305
306 #[inline]
308 pub fn new_ping(data: impl Into<Bytes>) -> Frame {
309 Frame {
310 header: FrameHeader { opcode: OpCode::Control(Control::Ping), ..<_>::default() },
311 payload: data.into(),
312 }
313 }
314
315 #[inline]
317 pub fn new_pong(data: impl Into<Bytes>) -> Frame {
318 Frame {
319 header: FrameHeader { opcode: OpCode::Control(Control::Pong), ..<_>::default() },
320 payload: data.into(),
321 }
322 }
323
324 #[inline]
326 pub fn new_close(msg: Option<CloseFrame>) -> Frame {
327 let payload = if let Some(CloseFrame { code, reason }) = msg {
328 let mut p = BytesMut::with_capacity(reason.len() + 2);
329 p.extend(u16::from(code).to_be_bytes());
330 p.extend_from_slice(reason.as_bytes());
331 p
332 } else {
333 <_>::default()
334 };
335
336 Frame { header: <_>::default(), payload: payload.into() }
337 }
338
339 pub fn new(header: FrameHeader, payload: Bytes) -> Self {
341 Frame { header, payload }
342 }
343
344 pub fn format_to_buf(mut self, output: &mut impl Write) -> Result<()> {
346 self.header.format(self.payload.len() as u64, output)?;
347
348 if let Some(mask) = self.header.mask.take() {
349 let mut data = Vec::from(mem::take(&mut self.payload));
350 apply_mask(&mut data, mask);
351
352 output.write_all(&data)?;
353 } else {
354 output.write_all(&self.payload)?;
355 }
356
357 Ok(())
358 }
359
360 pub(crate) fn into_buf(mut self, buf: &mut Vec<u8>) -> Result<()> {
361 self.header.format(self.payload.len() as u64, buf)?;
362
363 let len = buf.len();
364 buf.extend_from_slice(&self.payload);
365
366 if let Some(mask) = self.header.mask.take() {
367 apply_mask(&mut buf[len..], mask);
368 }
369
370 Ok(())
371 }
372}
373
374impl Display for Frame {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 use std::fmt::Write;
377
378 write!(
379 f,
380 "/
381 [FRAME]
382 final: {},
383 reserved: {} {} {},
384 opcode: {},
385 length: {},
386 payload-length: {},
387 payload: 0x{}
388 ",
389 self.header.fin,
390 self.header.rsv1,
391 self.header.rsv2,
392 self.header.rsv3,
393 self.header.opcode,
394 self.len(),
395 self.payload.len(),
396 self.payload.iter().fold(String::new(), |mut out, byte| {
397 _ = write!(out, "{byte:02x}");
398 out
399 })
400 )
401 }
402}
403
/// How a frame's payload length is encoded in the header (RFC 6455 §5.2).
enum Length {
    /// Length 0..=125, stored directly in the 7-bit length field.
    U8(u8),
    /// Length field 126: a 2-byte extended length follows.
    U16,
    /// Length field 127: an 8-byte extended length follows.
    U64,
}

impl Length {
    /// Chooses the smallest encoding that can represent `len`.
    #[inline]
    fn for_len(len: u64) -> Self {
        match len {
            0..=125 => Self::U8(len as u8),
            126..=0xFFFF => Self::U16,
            _ => Self::U64,
        }
    }

    /// Number of extended-length bytes that follow the 7-bit length field.
    #[inline]
    fn additional(&self) -> usize {
        match self {
            Self::U8(_) => 0,
            Self::U16 => 2,
            Self::U64 => 8,
        }
    }

    /// Value to place in the 7-bit length field for this encoding.
    #[inline]
    fn len_byte(&self) -> u8 {
        match self {
            &Self::U8(value) => value,
            Self::U16 => 126,
            Self::U64 => 127,
        }
    }

    /// Decodes the 7-bit length field; the high (mask) bit of `byte` is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let field = byte & 0x7F;
        if field == 126 {
            Self::U16
        } else if field == 127 {
            Self::U64
        } else {
            Self::U8(field)
        }
    }
}