1use crate::{
7 error::{Error, FrameError, Result},
8 protocol::{frame::*, Opcode},
9};
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12#[derive(Debug, Clone)]
14pub struct Frame {
15 pub fin: bool,
17 pub rsv: [bool; 3],
19 pub opcode: Opcode,
21 pub masked: bool,
23 pub mask: Option<[u8; 4]>,
25 pub payload: Bytes,
27}
28
29impl Frame {
30 pub fn new(opcode: Opcode, payload: impl Into<Bytes>) -> Self {
32 Self {
33 fin: true,
34 rsv: [false; 3],
35 opcode,
36 masked: false,
37 mask: None,
38 payload: payload.into(),
39 }
40 }
41
42 pub fn continuation(payload: impl Into<Bytes>) -> Self {
44 Self::new(Opcode::Continuation, payload)
45 }
46
47 pub fn text(payload: impl Into<Bytes>) -> Self {
49 Self::new(Opcode::Text, payload)
50 }
51
52 pub fn binary(payload: impl Into<Bytes>) -> Self {
54 Self::new(Opcode::Binary, payload)
55 }
56
57 pub fn close(code: Option<u16>, reason: Option<&str>) -> Self {
59 let mut payload = BytesMut::new();
60
61 if let Some(code) = code {
62 payload.put_u16(code);
63 }
64
65 if let Some(reason) = reason {
66 payload.put_slice(reason.as_bytes());
67 }
68
69 Self::new(Opcode::Close, payload.freeze())
70 }
71
72 pub fn ping(payload: impl Into<Bytes>) -> Self {
74 Self::new(Opcode::Ping, payload)
75 }
76
77 pub fn pong(payload: impl Into<Bytes>) -> Self {
79 Self::new(Opcode::Pong, payload)
80 }
81
82 pub fn fin(mut self, fin: bool) -> Self {
84 self.fin = fin;
85 self
86 }
87
88 pub fn rsv(mut self, rsv1: bool, rsv2: bool, rsv3: bool) -> Self {
90 self.rsv = [rsv1, rsv2, rsv3];
91 self
92 }
93
94 pub fn mask(mut self, enabled: bool) -> Self {
96 if enabled && !self.masked {
97 let mask = rand::random::<[u8; 4]>();
98 self.payload = mask_bytes(&self.payload, &mask);
99 self.masked = true;
100 self.mask = Some(mask);
101 } else if !enabled && self.masked {
102 if let Some(mask) = self.mask {
104 self.payload = mask_bytes(&self.payload, &mask);
105 }
106 self.masked = false;
107 self.mask = None;
108 }
109 self
110 }
111
112 pub fn to_bytes(&self) -> Bytes {
114 let mut buf = BytesMut::new();
115 self.write_to(&mut buf);
116 buf.freeze()
117 }
118
119 pub fn write_to(&self, buf: &mut BytesMut) {
121 let first_byte = ((self.fin as u8) << 7)
123 | ((self.rsv[0] as u8) << 6)
124 | ((self.rsv[1] as u8) << 5)
125 | ((self.rsv[2] as u8) << 4)
126 | self.opcode.value();
127 buf.put_u8(first_byte);
128
129 let payload_len = self.payload.len();
131 let mask_bit = (self.masked as u8) << 7;
132
133 if payload_len < 126 {
134 buf.put_u8(mask_bit | payload_len as u8);
135 } else if payload_len <= u16::MAX as usize {
136 buf.put_u8(mask_bit | PAYLOAD_LEN_16);
137 buf.put_u16(payload_len as u16);
138 } else {
139 buf.put_u8(mask_bit | PAYLOAD_LEN_64);
140 buf.put_u64(payload_len as u64);
141 }
142
143 if let Some(mask) = self.mask {
145 buf.put_slice(&mask);
146 }
147
148 buf.put_slice(&self.payload);
150 }
151
152 pub fn parse(buf: &mut BytesMut) -> Result<Self> {
154 if buf.len() < 2 {
155 return Err(FrameError::InsufficientData {
156 needed: 2,
157 have: buf.len(),
158 }
159 .into());
160 }
161
162 let mut cursor = std::io::Cursor::new(&buf[..]);
163
164 let first_byte = cursor.get_u8();
166 let fin = (first_byte & FIN_BIT) != 0;
167 let rsv1 = (first_byte & RSV1_BIT) != 0;
168 let rsv2 = (first_byte & RSV2_BIT) != 0;
169 let rsv3 = (first_byte & RSV3_BIT) != 0;
170 let opcode = Opcode::from(first_byte & OPCODE_MASK)
171 .ok_or(FrameError::InvalidOpcode(first_byte & OPCODE_MASK))?;
172
173 let second_byte = cursor.get_u8();
175 let masked = (second_byte & MASK_BIT) != 0;
176 let mut payload_len = (second_byte & PAYLOAD_LEN_MASK) as usize;
177
178 if payload_len == 126 {
180 if buf.len() < 4 {
181 return Err(FrameError::InsufficientData {
182 needed: 4,
183 have: buf.len(),
184 }
185 .into());
186 }
187 payload_len = cursor.get_u16() as usize;
188 } else if payload_len == 127 {
189 if buf.len() < 10 {
190 return Err(FrameError::InsufficientData {
191 needed: 10,
192 have: buf.len(),
193 }
194 .into());
195 }
196 payload_len = cursor.get_u64() as usize;
197 }
198
199 let mask = if masked {
201 if buf.len() < cursor.position() as usize + 4 + payload_len {
202 return Err(FrameError::InsufficientData {
203 needed: cursor.position() as usize + 4 + payload_len,
204 have: buf.len(),
205 }
206 .into());
207 }
208 let mut mask = [0u8; 4];
209 cursor.copy_to_slice(&mut mask);
210 Some(mask)
211 } else {
212 None
213 };
214
215 if buf.len() < cursor.position() as usize + payload_len {
217 return Err(FrameError::InsufficientData {
218 needed: cursor.position() as usize + payload_len,
219 have: buf.len(),
220 }
221 .into());
222 }
223
224 let mut payload = Bytes::copy_from_slice(
225 &buf[cursor.position() as usize..cursor.position() as usize + payload_len],
226 );
227
228 if let Some(mask) = mask {
230 payload = mask_bytes(&payload, &mask);
231 }
232
233 let frame_len = cursor.position() as usize + payload_len;
235 buf.advance(frame_len);
236
237 if opcode.is_control() && !fin {
239 return Err(FrameError::FragmentedControlFrame.into());
240 }
241
242 if rsv1 || rsv2 || rsv3 {
243 return Err(FrameError::ReservedBitsSet.into());
244 }
245
246 Ok(Frame {
247 fin,
248 rsv: [rsv1, rsv2, rsv3],
249 opcode,
250 masked,
251 mask,
252 payload,
253 })
254 }
255
256 pub fn kind(&self) -> FrameKind {
258 match self.opcode {
259 Opcode::Text => FrameKind::Text,
260 Opcode::Binary => FrameKind::Binary,
261 Opcode::Close => FrameKind::Close,
262 Opcode::Ping => FrameKind::Ping,
263 Opcode::Pong => FrameKind::Pong,
264 Opcode::Continuation => FrameKind::Continuation,
265 _ => FrameKind::Reserved,
266 }
267 }
268
269 pub fn payload_len(&self) -> usize {
271 self.payload.len()
272 }
273
274 pub fn is_control(&self) -> bool {
276 self.opcode.is_control()
277 }
278
279 pub fn is_data(&self) -> bool {
281 self.opcode.is_data()
282 }
283
284 pub fn is_final(&self) -> bool {
286 self.fin
287 }
288}
289
/// Coarse classification of a frame, as returned by `Frame::kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameKind {
    /// Text data frame.
    Text,
    /// Binary data frame.
    Binary,
    /// Close control frame.
    Close,
    /// Ping control frame.
    Ping,
    /// Pong control frame.
    Pong,
    /// Continuation of a fragmented message.
    Continuation,
    /// Any opcode that is not mapped to one of the variants above.
    Reserved,
}
308
309fn mask_bytes(data: &[u8], mask: &[u8; 4]) -> Bytes {
311 let mut masked = BytesMut::with_capacity(data.len());
312 for (i, &byte) in data.iter().enumerate() {
313 masked.put_u8(byte ^ mask[i % 4]);
314 }
315 masked.freeze()
316}
317
318#[derive(Debug, Default)]
320pub struct FrameParser {
321 buffer: BytesMut,
323 expected_size: Option<usize>,
325}
326
327impl FrameParser {
328 pub fn new() -> Self {
330 Self::default()
331 }
332
333 pub fn feed(&mut self, data: &[u8]) -> Vec<Result<Frame>> {
335 self.buffer.extend_from_slice(data);
336 self.extract_frames()
337 }
338
339 fn extract_frames(&mut self) -> Vec<Result<Frame>> {
341 let mut frames = Vec::new();
342
343 while let Some(frame) = self.try_parse_frame() {
344 match frame {
345 Ok(f) => frames.push(Ok(f)),
346 Err(e) => {
347 frames.push(Err(e));
348 break;
349 }
350 }
351 }
352
353 frames
354 }
355
356 fn try_parse_frame(&mut self) -> Option<Result<Frame>> {
358 let mut buf = self.buffer.clone();
359
360 match Frame::parse(&mut buf) {
361 Ok(frame) => {
362 let parsed_len = self.buffer.len() - buf.len();
364 self.buffer.advance(parsed_len);
365 Some(Ok(frame))
366 }
367 Err(Error::Frame(FrameError::InsufficientData { .. })) => {
368 None
370 }
371 Err(e) => {
372 self.buffer.clear();
374 Some(Err(e))
375 }
376 }
377 }
378
379 pub fn buffered_bytes(&self) -> usize {
381 self.buffer.len()
382 }
383
384 pub fn clear(&mut self) {
386 self.buffer.clear();
387 self.expected_size = None;
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_frame_serialization() {
        let bytes = Frame::text("hello").to_bytes();

        // FIN + text opcode, unmasked 5-byte length, then the raw payload.
        assert_eq!(bytes[0], 0x81);
        assert_eq!(bytes[1], 0x05);
        assert_eq!(&bytes[2..], b"hello");
    }

    #[test]
    fn test_masked_frame() {
        let bytes = Frame::text("hello").mask(true).to_bytes();

        // Mask bit set; 2-byte header + 4-byte key + 5-byte payload.
        assert_eq!(bytes[1] & 0x80, 0x80);
        assert_eq!(bytes.len(), 2 + 4 + 5);
    }

    #[test]
    fn test_frame_parsing() {
        let encoded = Frame::text("hello").to_bytes();
        let mut buf = BytesMut::from(&encoded[..]);

        let parsed = Frame::parse(&mut buf).unwrap();

        assert_eq!(parsed.kind(), FrameKind::Text);
        assert_eq!(parsed.payload, "hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_large_frame() {
        let bytes = Frame::binary(vec![0u8; 65536]).to_bytes();

        // 65536 does not fit in 16 bits, so the 64-bit extended length is used.
        assert_eq!(bytes[1], 127);
        assert_eq!(bytes[2..10], (65536u64).to_be_bytes());
    }

    #[test]
    fn test_close_frame() {
        let bytes = Frame::close(Some(1000), Some("Goodbye")).to_bytes();

        // FIN + close opcode; body = 2-byte status code + 7-byte reason.
        assert_eq!(bytes[0], 0x88);
        assert_eq!(bytes[1], 0x09);
        assert_eq!(&bytes[2..4], 1000u16.to_be_bytes());
        assert_eq!(&bytes[4..], b"Goodbye");
        assert_eq!(bytes.len(), 11);
    }

    #[test]
    fn test_frame_parser() {
        let mut parser = FrameParser::new();
        let text_bytes = Frame::text("frame1").to_bytes();
        let ping_bytes = Frame::ping("ping").to_bytes();

        // A partial frame yields nothing until the remainder arrives.
        assert!(parser.feed(&text_bytes[..5]).is_empty());

        let completed = parser.feed(&text_bytes[5..]);
        assert_eq!(completed.len(), 1);
        assert!(completed[0].as_ref().unwrap().is_data());

        let completed = parser.feed(&ping_bytes);
        assert_eq!(completed.len(), 1);
        assert!(completed[0].as_ref().unwrap().is_control());
    }
}
472}