1use crate::{
7 error::{Error, FrameError, Result},
8 protocol::{frame::*, Opcode},
9};
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
/// A single WebSocket frame: header flags, optional masking key and payload.
///
/// Frames returned by [`Frame::parse`] carry a payload that has already been
/// unmasked, even when `masked` is `true`; `mask` then records the key that was
/// read from the wire.
#[derive(Debug, Clone)]
pub struct Frame {
    /// FIN bit: set on the final fragment of a message.
    pub fin: bool,
    /// RSV1, RSV2 and RSV3 bits, in that order. RSV1 doubles as the
    /// "compressed" flag when the `compression` feature is used.
    pub rsv: [bool; 3],
    /// Frame opcode.
    pub opcode: Opcode,
    /// MASK bit of the frame header.
    pub masked: bool,
    /// 4-byte masking key; set alongside `masked` by the builder and the parser.
    pub mask: Option<[u8; 4]>,
    /// Payload bytes.
    pub payload: Bytes,
}
28
29impl Frame {
30 pub fn new(opcode: Opcode, payload: impl Into<Bytes>) -> Self {
32 Self {
33 fin: true,
34 rsv: [false; 3],
35 opcode,
36 masked: false,
37 mask: None,
38 payload: payload.into(),
39 }
40 }
41
42 pub fn continuation(payload: impl Into<Bytes>) -> Self {
44 Self::new(Opcode::Continuation, payload)
45 }
46
47 pub fn text(payload: impl Into<Bytes>) -> Self {
49 Self::new(Opcode::Text, payload)
50 }
51
52 pub fn binary(payload: impl Into<Bytes>) -> Self {
54 Self::new(Opcode::Binary, payload)
55 }
56
57 pub fn close(code: Option<u16>, reason: Option<&str>) -> Self {
59 let mut payload = BytesMut::new();
60
61 if let Some(code) = code {
62 payload.put_u16(code);
63 }
64
65 if let Some(reason) = reason {
66 payload.put_slice(reason.as_bytes());
67 }
68
69 Self::new(Opcode::Close, payload.freeze())
70 }
71
72 pub fn ping(payload: impl Into<Bytes>) -> Self {
74 Self::new(Opcode::Ping, payload)
75 }
76
77 pub fn pong(payload: impl Into<Bytes>) -> Self {
79 Self::new(Opcode::Pong, payload)
80 }
81
82 pub fn fin(mut self, fin: bool) -> Self {
84 self.fin = fin;
85 self
86 }
87
88 pub fn rsv(mut self, rsv1: bool, rsv2: bool, rsv3: bool) -> Self {
90 self.rsv = [rsv1, rsv2, rsv3];
91 self
92 }
93
94 pub fn mask(mut self, enabled: bool) -> Self {
96 if enabled && !self.masked {
97 let mask = rand::random::<[u8; 4]>();
98 self.payload = mask_bytes(&self.payload, &mask);
99 self.masked = true;
100 self.mask = Some(mask);
101 } else if !enabled && self.masked {
102 if let Some(mask) = self.mask {
104 self.payload = mask_bytes(&self.payload, &mask);
105 }
106 self.masked = false;
107 self.mask = None;
108 }
109 self
110 }
111
112 #[cfg(feature = "compression")]
114 pub fn compress(mut self, enabled: bool) -> Self {
115 if enabled && self.opcode.is_data() && !self.rsv[0] {
116 use flate2::write::DeflateEncoder;
117 use flate2::Compression;
118 use std::io::Write;
119
120 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::new(6));
121 if encoder.write_all(&self.payload).is_ok() && encoder.flush().is_ok() {
122 if let Ok(compressed) = encoder.finish() {
123 self.payload = Bytes::from(compressed);
124 self.rsv[0] = true;
125 }
126 }
127 }
128 self
129 }
130
131 pub fn to_bytes(&self) -> Bytes {
133 let mut buf = BytesMut::new();
134 self.write_to(&mut buf);
135 buf.freeze()
136 }
137
138 pub fn write_to(&self, buf: &mut BytesMut) {
140 let first_byte = ((self.fin as u8) << 7)
142 | ((self.rsv[0] as u8) << 6)
143 | ((self.rsv[1] as u8) << 5)
144 | ((self.rsv[2] as u8) << 4)
145 | self.opcode.value();
146 buf.put_u8(first_byte);
147
148 let payload_len = self.payload.len();
150 let mask_bit = (self.masked as u8) << 7;
151
152 if payload_len < 126 {
153 buf.put_u8(mask_bit | payload_len as u8);
154 } else if payload_len <= u16::MAX as usize {
155 buf.put_u8(mask_bit | PAYLOAD_LEN_16);
156 buf.put_u16(payload_len as u16);
157 } else {
158 buf.put_u8(mask_bit | PAYLOAD_LEN_64);
159 buf.put_u64(payload_len as u64);
160 }
161
162 if let Some(mask) = self.mask {
164 buf.put_slice(&mask);
165 }
166
167 buf.put_slice(&self.payload);
169 }
170
171 pub fn parse(buf: &mut BytesMut, compression_enabled: bool) -> Result<Self> {
173 if buf.len() < 2 {
174 return Err(FrameError::InsufficientData {
175 needed: 2,
176 have: buf.len(),
177 }
178 .into());
179 }
180
181 let mut cursor = std::io::Cursor::new(&buf[..]);
182
183 let first_byte = cursor.get_u8();
185 let fin = (first_byte & FIN_BIT) != 0;
186 let rsv1 = (first_byte & RSV1_BIT) != 0;
187 let rsv2 = (first_byte & RSV2_BIT) != 0;
188 let rsv3 = (first_byte & RSV3_BIT) != 0;
189 let opcode = Opcode::from(first_byte & OPCODE_MASK)
190 .ok_or(FrameError::InvalidOpcode(first_byte & OPCODE_MASK))?;
191
192 let second_byte = cursor.get_u8();
194 let masked = (second_byte & MASK_BIT) != 0;
195 let mut payload_len = (second_byte & PAYLOAD_LEN_MASK) as usize;
196
197 if payload_len == 126 {
199 if buf.len() < 4 {
200 return Err(FrameError::InsufficientData {
201 needed: 4,
202 have: buf.len(),
203 }
204 .into());
205 }
206 payload_len = cursor.get_u16() as usize;
207 } else if payload_len == 127 {
208 if buf.len() < 10 {
209 return Err(FrameError::InsufficientData {
210 needed: 10,
211 have: buf.len(),
212 }
213 .into());
214 }
215 payload_len = cursor.get_u64() as usize;
216 }
217
218 let mask = if masked {
220 if buf.len() < cursor.position() as usize + 4 + payload_len {
221 return Err(FrameError::InsufficientData {
222 needed: cursor.position() as usize + 4 + payload_len,
223 have: buf.len(),
224 }
225 .into());
226 }
227 let mut mask = [0u8; 4];
228 cursor.copy_to_slice(&mut mask);
229 Some(mask)
230 } else {
231 None
232 };
233
234 if buf.len() < cursor.position() as usize + payload_len {
236 return Err(FrameError::InsufficientData {
237 needed: cursor.position() as usize + payload_len,
238 have: buf.len(),
239 }
240 .into());
241 }
242
243 let mut payload = Bytes::copy_from_slice(
244 &buf[cursor.position() as usize..cursor.position() as usize + payload_len],
245 );
246
247 if let Some(mask) = mask {
249 payload = mask_bytes(&payload, &mask);
250 }
251
252 #[cfg(feature = "compression")]
254 if rsv1 && compression_enabled {
255 use flate2::read::DeflateDecoder;
256 use std::io::Read;
257
258 let mut decoder = DeflateDecoder::new(&payload[..]);
259 let mut decompressed = Vec::new();
260 if decoder.read_to_end(&mut decompressed).is_err() {
261 return Err(FrameError::DecompressionFailed.into());
262 }
263 payload = Bytes::from(decompressed);
264 }
265
266 let frame_len = cursor.position() as usize + payload_len;
268 buf.advance(frame_len);
269
270 if opcode.is_control() && !fin {
272 return Err(FrameError::FragmentedControlFrame.into());
273 }
274
275 if (rsv1 && !(compression_enabled && opcode.is_data())) || rsv2 || rsv3 {
276 return Err(FrameError::ReservedBitsSet.into());
277 }
278
279 Ok(Frame {
280 fin,
281 rsv: [rsv1, rsv2, rsv3],
282 opcode,
283 masked,
284 mask,
285 payload,
286 })
287 }
288
289 pub fn kind(&self) -> FrameKind {
291 match self.opcode {
292 Opcode::Text => FrameKind::Text,
293 Opcode::Binary => FrameKind::Binary,
294 Opcode::Close => FrameKind::Close,
295 Opcode::Ping => FrameKind::Ping,
296 Opcode::Pong => FrameKind::Pong,
297 Opcode::Continuation => FrameKind::Continuation,
298 _ => FrameKind::Reserved,
299 }
300 }
301
302 pub fn payload_len(&self) -> usize {
304 self.payload.len()
305 }
306
307 pub fn is_control(&self) -> bool {
309 self.opcode.is_control()
310 }
311
312 pub fn is_data(&self) -> bool {
314 self.opcode.is_data()
315 }
316
317 pub fn is_final(&self) -> bool {
319 self.fin
320 }
321}
322
/// Coarse classification of a frame by its opcode, as returned by `Frame::kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameKind {
    /// Text data frame.
    Text,
    /// Binary data frame.
    Binary,
    /// Close control frame.
    Close,
    /// Ping control frame.
    Ping,
    /// Pong control frame.
    Pong,
    /// Continuation fragment of a message.
    Continuation,
    /// Any opcode not covered by the variants above.
    Reserved,
}
341
342fn mask_bytes(data: &[u8], mask: &[u8; 4]) -> Bytes {
344 let mut masked = BytesMut::with_capacity(data.len());
345 for (i, &byte) in data.iter().enumerate() {
346 masked.put_u8(byte ^ mask[i % 4]);
347 }
348 masked.freeze()
349}
350
/// Incremental frame decoder: bytes are accumulated through `FrameParser::feed`,
/// and every complete frame found in the buffer is returned.
#[derive(Debug)]
pub struct FrameParser {
    /// Received bytes not yet consumed as part of a complete frame.
    buffer: BytesMut,
    /// Number of buffered bytes the next frame needs, if recorded; reset by `clear`.
    expected_size: Option<usize>,
    /// Forwarded to `Frame::parse`; controls whether RSV1 (compressed) data
    /// frames are accepted.
    compression_enabled: bool,
}
361
362impl Default for FrameParser {
363 fn default() -> Self {
364 Self {
365 buffer: BytesMut::new(),
366 expected_size: None,
367 compression_enabled: false,
368 }
369 }
370}
371
372impl FrameParser {
373 pub fn new() -> Self {
375 Self::default()
376 }
377
378 pub fn with_compression(compression_enabled: bool) -> Self {
380 Self {
381 buffer: BytesMut::new(),
382 expected_size: None,
383 compression_enabled,
384 }
385 }
386
387 pub fn feed(&mut self, data: &[u8]) -> Vec<Result<Frame>> {
389 self.buffer.extend_from_slice(data);
390 self.extract_frames()
391 }
392
393 fn extract_frames(&mut self) -> Vec<Result<Frame>> {
395 let mut frames = Vec::new();
396
397 while let Some(frame) = self.try_parse_frame() {
398 match frame {
399 Ok(f) => frames.push(Ok(f)),
400 Err(e) => {
401 frames.push(Err(e));
402 break;
403 }
404 }
405 }
406
407 frames
408 }
409
410 fn try_parse_frame(&mut self) -> Option<Result<Frame>> {
412 let mut buf = self.buffer.clone();
413
414 match Frame::parse(&mut buf, self.compression_enabled) {
415 Ok(frame) => {
416 let parsed_len = self.buffer.len() - buf.len();
418 self.buffer.advance(parsed_len);
419 Some(Ok(frame))
420 }
421 Err(Error::Frame(FrameError::InsufficientData { .. })) => {
422 None
424 }
425 Err(e) => {
426 self.buffer.clear();
428 Some(Err(e))
429 }
430 }
431 }
432
433 pub fn buffered_bytes(&self) -> usize {
435 self.buffer.len()
436 }
437
438 pub fn clear(&mut self) {
440 self.buffer.clear();
441 self.expected_size = None;
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_frame_serialization() {
        let frame = Frame::text("hello");
        let bytes = frame.to_bytes();

        assert_eq!(bytes[0], 0x81); // FIN + text opcode
        assert_eq!(bytes[1], 0x05); // unmasked, 5-byte payload
        assert_eq!(&bytes[2..], b"hello");
    }

    #[test]
    fn test_masked_frame() {
        let frame = Frame::text("hello").mask(true);
        let bytes = frame.to_bytes();

        assert_eq!(bytes[1] & 0x80, 0x80); // MASK bit set
        assert_eq!(bytes.len(), 2 + 4 + 5); // header + key + payload
    }

    #[test]
    fn test_masked_frame_round_trip() {
        let bytes = Frame::text("payload").mask(true).to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let parsed = Frame::parse(&mut buf, false).unwrap();
        assert!(parsed.masked);
        assert!(parsed.mask.is_some());
        assert_eq!(parsed.payload, "payload");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_frame_parsing() {
        let original = Frame::text("hello");
        let bytes = original.to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let parsed = Frame::parse(&mut buf, false).unwrap();
        assert_eq!(parsed.kind(), FrameKind::Text);
        assert_eq!(parsed.payload, "hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_extended_16_bit_length() {
        let payload = vec![7u8; 200];
        let bytes = Frame::binary(payload.clone()).to_bytes();

        assert_eq!(bytes[1], 126);
        assert_eq!(&bytes[2..4], 200u16.to_be_bytes());

        let mut buf = BytesMut::from(&bytes[..]);
        let parsed = Frame::parse(&mut buf, false).unwrap();
        assert_eq!(&parsed.payload[..], &payload[..]);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_large_frame() {
        let payload = vec![0u8; 65536]; // one past the 16-bit length range
        let frame = Frame::binary(payload);
        let bytes = frame.to_bytes();

        assert_eq!(bytes[1], 127);
        assert_eq!(bytes[2..10], (65536u64).to_be_bytes());
    }

    #[test]
    fn test_partial_frame_is_not_consumed() {
        let bytes = Frame::text("hello").to_bytes();
        let mut buf = BytesMut::from(&bytes[..4]);

        let err = Frame::parse(&mut buf, false).unwrap_err();
        assert!(matches!(
            err,
            Error::Frame(FrameError::InsufficientData { .. })
        ));
        assert_eq!(buf.len(), 4);
    }

    #[test]
    fn test_fragmented_control_frame_rejected() {
        let bytes = Frame::ping("x").fin(false).to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let err = Frame::parse(&mut buf, false).unwrap_err();
        assert!(matches!(
            err,
            Error::Frame(FrameError::FragmentedControlFrame)
        ));
    }

    #[test]
    fn test_reserved_bits_rejected() {
        let bytes = Frame::text("x").rsv(false, true, false).to_bytes();
        let mut buf = BytesMut::from(&bytes[..]);

        let err = Frame::parse(&mut buf, false).unwrap_err();
        assert!(matches!(err, Error::Frame(FrameError::ReservedBitsSet)));
    }

    #[test]
    fn test_close_frame() {
        let frame = Frame::close(Some(1000), Some("Goodbye"));
        let bytes = frame.to_bytes();

        assert_eq!(bytes[0], 0x88); // FIN + close opcode
        assert_eq!(bytes[1], 0x09); // 2-byte code + 7-byte reason
        assert_eq!(&bytes[2..4], 1000u16.to_be_bytes());
        assert_eq!(&bytes[4..], b"Goodbye");
        assert_eq!(bytes.len(), 11);
    }

    #[test]
    fn test_frame_parser() {
        let mut parser = FrameParser::new();

        let bytes1 = Frame::text("frame1").to_bytes();
        let bytes2 = Frame::ping("ping").to_bytes();

        // A partial frame yields nothing until the rest arrives.
        let frames = parser.feed(&bytes1[..5]);
        assert_eq!(frames.len(), 0);

        let frames = parser.feed(&bytes1[5..]);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].as_ref().unwrap().is_data());

        let frames = parser.feed(&bytes2);
        assert_eq!(frames.len(), 1);
        assert!(frames[0].as_ref().unwrap().is_control());
    }

    #[test]
    fn test_frame_parser_multiple_frames_in_one_feed() {
        let mut parser = FrameParser::new();
        let mut data = Frame::text("a").to_bytes().to_vec();
        data.extend_from_slice(&Frame::pong("b").to_bytes());

        let frames = parser.feed(&data);
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].as_ref().unwrap().kind(), FrameKind::Text);
        assert_eq!(frames[1].as_ref().unwrap().kind(), FrameKind::Pong);
        assert_eq!(parser.buffered_bytes(), 0);
    }
}