ws_gonzale/
dataframe.rs

1use crate::{message::Message, WsGonzaleError, WsGonzaleResult};
2
3/// Converts a [`Message`] to a `Vec<u8>`
4#[inline(always)]
5pub fn get_buffer(message: Message) -> Vec<u8> {
6    let mut buffer: Vec<u8> = Vec::new();
7    buffer.push(129);
8    let s = match message {
9        Message::Text(s) => s,
10        _ => "".to_owned(),
11    };
12    match s.len() as u64 {
13        size @ 0..=125 => {
14            buffer.push(size as u8);
15        }
16        size if size > u32::MAX as u64 => {
17            let bytes: [u8; 8] = (size as u64).to_be_bytes();
18            buffer.push(127);
19            buffer.extend_from_slice(&bytes);
20        }
21        size if size <= u32::MAX as u64 => {
22            let bytes: [u8; 2] = (size as u16).to_be_bytes();
23            buffer.push(126);
24            buffer.extend_from_slice(&bytes);
25        }
26        _ => panic!("Don't know what to do here..."),
27    }
28    buffer.extend_from_slice(s.as_bytes());
29    buffer
30}
31#[inline(always)]
32/// This masks the payload byte by byte and does a bitwise exclusive on index % 4 of mask
33pub fn mask_payload<'a, 'b>(incoming: &'a mut &'b mut [u8], mask: [u8; 4]) -> &'a [u8] {
34    let data: &'b mut [u8] = std::mem::take(incoming);
35    for i in 0..data.len() {
36        data[i] ^= mask[i % 4];
37    }
38    data
39}
/// Parser input: wraps the raw bytes of one frame (starting at the first
/// header byte). Use [`DataframeBuilder::new`] to parse them into a [`Dataframe`].
pub struct DataframeBuilder(Vec<u8>);
/// A single parsed WebSocket frame, produced by [`DataframeBuilder`].
#[derive(Debug)]
pub struct Dataframe {
    // FIN bit (0x80) of the first header byte.
    fin: bool,
    // Reserved bits RSV1..RSV3 of the first header byte.
    rsv1: bool,
    rsv2: bool,
    rsv3: bool,
    // MASK bit (0x80) of the second header byte.
    is_mask: bool,
    // Low four bits of the first header byte.
    opcode: u8,
    // Payload length declared in the header (7-bit, 16-bit or 64-bit form).
    payload_length: u64,
    // Header size plus `payload_length`; the builder's header size always
    // includes 4 bytes for a masking key.
    full_frame_length: u64,
    // Masking key, or [0, 0, 0, 0] when the MASK bit is clear or the buffer is
    // too short to contain the key.
    masking_key: [u8; 4],
    // Unmasked payload bytes; may be shorter than `payload_length` if the
    // source buffer ended early.
    payload: Vec<u8>,
}
/// Frame opcodes defined by RFC 6455 §5.2 that this parser distinguishes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Opcode {
    Continuation = 0,
    Text = 1,
    Binary = 2,
    Close = 8,
    Ping = 9,
    Pong = 10,
    /// Any reserved (3-7, 11-15) or otherwise unrecognised opcode.
    Unknown,
}
impl From<u8> for Opcode {
    /// Maps the raw 4-bit opcode from a frame header to an [`Opcode`].
    fn from(v: u8) -> Opcode {
        match v {
            0 => Opcode::Continuation,
            1 => Opcode::Text,
            2 => Opcode::Binary,
            8 => Opcode::Close,
            9 => Opcode::Ping,
            10 => Opcode::Pong,
            _ => Opcode::Unknown,
        }
    }
}
77
/// Classification of the 7-bit payload-length field in the second header byte:
/// how many extended-length bytes follow it.
#[derive(Debug)]
enum ExtraSize {
    /// Field was 0..=125: no extended length; carries the length itself.
    Zero(u8),
    /// Field was 126: a 16-bit big-endian length follows (bytes 2..4).
    Two,
    /// Field was 127: a 64-bit big-endian length follows (bytes 2..10).
    Eight,
}
/// Bit masks for reading the first two bytes of a frame header.
mod frame_positions {
    // ---- Header byte 0: FIN, RSV1..RSV3 and the 4-bit opcode ----
    pub const FIN: u8 = 0b1000_0000;
    pub const RSV1: u8 = 0b0100_0000;
    pub const RSV2: u8 = 0b0010_0000;
    pub const RSV3: u8 = 0b0001_0000;
    pub const MASK_OPCODE: u8 = 0b0000_1111;

    // ---- Header byte 1: MASK flag and the 7-bit payload length ----
    pub const IS_MASK: u8 = 0b1000_0000;
    pub const MASK_PAYLOAD_LENGTH: u8 = 0b0111_1111;
}
95impl DataframeBuilder {
96    pub fn new(buffer: Vec<u8>) -> WsGonzaleResult<Dataframe> {
97        DataframeBuilder(buffer).get_dataframe()
98    }
99    #[inline(always)]
100    fn is_fin(&self) -> bool {
101        self.0
102            .get(0)
103            .map(|frame| (frame & frame_positions::FIN) == frame_positions::FIN)
104            .unwrap_or(false)
105    }
106    #[inline(always)]
107    fn is_rsv1(&self) -> bool {
108        self.0
109            .get(0)
110            .map(|frame| (frame & frame_positions::RSV1) == frame_positions::RSV1)
111            .unwrap_or(false)
112    }
113    #[inline(always)]
114    fn is_rsv2(&self) -> bool {
115        self.0
116            .get(0)
117            .map(|frame| (frame & frame_positions::RSV2) == frame_positions::RSV2)
118            .unwrap_or(false)
119    }
120    #[inline(always)]
121    fn is_rsv3(&self) -> bool {
122        self.0
123            .get(0)
124            .map(|frame| (frame & frame_positions::RSV3) == frame_positions::RSV3)
125            .unwrap_or(false)
126    }
127    #[inline(always)]
128    /// Get the last four bits in one byte in first frame
129    fn get_opcode(&self) -> u8 {
130        // default to close
131        self.0
132            .get(0)
133            .map(|frame| frame & frame_positions::MASK_OPCODE)
134            .unwrap_or(8)
135    }
136    #[inline(always)]
137    fn is_mask(&self) -> bool {
138        self.0
139            .get(1)
140            .map(|frame| (frame & frame_positions::IS_MASK) == frame_positions::IS_MASK)
141            .unwrap_or(false)
142    }
143    /// Get the last seven bits in the byte in the second frame
144    #[inline(always)]
145    fn get_short_payload_length(&self) -> u8 {
146        self.0
147            .get(1)
148            .map(|frame| frame & frame_positions::MASK_PAYLOAD_LENGTH)
149            .unwrap_or(0)
150    }
151    #[inline(always)]
152    fn get_extra_payload_bytes(&self) -> WsGonzaleResult<ExtraSize> {
153        let result = match self.get_short_payload_length() {
154            size @ 0..=125 => ExtraSize::Zero(size),
155            126 => ExtraSize::Two,
156            127 => ExtraSize::Eight,
157            _ => unreachable!("Max payload for a dataframe in WS spec is 127"),
158        };
159        Ok(result)
160    }
161    #[inline(always)]
162    fn get_payload_length(&self) -> WsGonzaleResult<u64> {
163        let slice = self.0.as_slice();
164        let result = match self.get_extra_payload_bytes()? {
165            ExtraSize::Zero(size) => size as u64,
166            ExtraSize::Two => match slice {
167                [_, _, first, second, ..] if slice.len() > 4 => {
168                    u32::from_be_bytes([0, 0, *first, *second]) as u64
169                }
170                _ => return Err(WsGonzaleError::Unknown),
171            },
172            ExtraSize::Eight => match slice {
173                [_, _, first, second, third, fourth, fifth, sixth, seventh, eighth, ..]
174                    if slice.len() > 8 =>
175                {
176                    u64::from_be_bytes([
177                        *first, *second, *third, *fourth, *fifth, *sixth, *seventh, *eighth,
178                    ]) as u64
179                }
180                _ => return Err(WsGonzaleError::Unknown),
181            },
182        };
183
184        Ok(result)
185    }
186
187    fn get_payload_start_pos(&self) -> WsGonzaleResult<u64> {
188        let result = match self.get_extra_payload_bytes()? {
189            ExtraSize::Zero(_) => 6,
190            ExtraSize::Two => 8,
191            ExtraSize::Eight => 14,
192        };
193        Ok(result)
194    }
195    pub fn get_full_frame_length(&self) -> WsGonzaleResult<u64> {
196        let size = self.get_payload_start_pos()? + self.get_payload_length()?;
197
198        Ok(size)
199    }
200    #[inline(always)]
201    fn get_masking_key_start(&self) -> WsGonzaleResult<u8> {
202        let result = match self.get_extra_payload_bytes()? {
203            ExtraSize::Zero(_) => 0,
204            ExtraSize::Two => 2,
205            ExtraSize::Eight => 8,
206        };
207        Ok(result)
208    }
209    #[inline(always)]
210    fn get_masking_key(&self) -> WsGonzaleResult<[u8; 4]> {
211        let start = 2 + self.get_masking_key_start()? as usize;
212        let end = start + 4;
213        if self.is_mask() && self.0.len() >= end {
214            let mut buffer: [u8; 4] = [0; 4];
215            buffer.copy_from_slice(&self.0[start..end]);
216            Ok(buffer)
217        } else {
218            // masking key [0, 0, 0, 0] is ok because 1 ^ 0 == 1, 0 ^ 0 == 0
219            Ok([0, 0, 0, 0])
220        }
221    }
222    #[inline(always)]
223    fn get_payload(mut self) -> WsGonzaleResult<Vec<u8>> {
224        let start_payload = self.get_payload_start_pos()? as usize;
225        let is_mask = self.is_mask();
226        let masking_key = self.get_masking_key()?;
227        let payload_length = self.get_payload_length()? as usize;
228
229        if Opcode::from(self.get_opcode()) == Opcode::Close {
230            // TODO: Add support for reason message for closing
231            return Err(WsGonzaleError::ConnectionClosed);
232        }
233        if start_payload > self.0.len() {
234            return Err(WsGonzaleError::InvalidPayload);
235        }
236        // Remove first {start_payload}:th bytes from dataframe payload
237        self.0.drain(0..start_payload);
238        let mut data = self.0.into_iter().take(payload_length).collect::<Vec<u8>>();
239        if is_mask {
240            mask_payload(&mut &mut *data, masking_key);
241        }
242        Ok(data)
243    }
244    #[inline(always)]
245    fn get_dataframe(self) -> WsGonzaleResult<Dataframe> {
246        let result = Dataframe {
247            fin: self.is_fin(),
248            rsv1: self.is_rsv1(),
249            rsv2: self.is_rsv2(),
250            rsv3: self.is_rsv3(),
251            is_mask: self.is_mask(),
252            opcode: self.get_opcode(),
253            payload_length: self.get_payload_length()?,
254            full_frame_length: self.get_full_frame_length()?,
255            masking_key: self.get_masking_key()?,
256            payload: self.get_payload()?,
257        };
258        Ok(result)
259    }
260}
261impl Dataframe {
262    #[inline(always)]
263    pub fn get_message(self) -> WsGonzaleResult<Message> {
264        let result = match self.opcode {
265            1 => Message::Text(
266                String::from_utf8_lossy(&self.get_payload())
267                    .parse()
268                    .map_err(|_| WsGonzaleError::InvalidPayload)?,
269            ),
270            8 => Message::Close,
271            _ => Message::Unknown,
272        };
273        Ok(result)
274    }
275    #[inline(always)]
276    pub fn is_fin(&self) -> bool {
277        self.fin
278    }
279    #[inline(always)]
280    pub fn is_rsv1(&self) -> bool {
281        self.rsv1
282    }
283    #[inline(always)]
284    pub fn is_rsv2(&self) -> bool {
285        self.rsv2
286    }
287    #[inline(always)]
288    pub fn is_rsv3(&self) -> bool {
289        self.rsv3
290    }
291    #[inline(always)]
292    pub fn get_opcode(&self) -> u8 {
293        self.opcode
294    }
295    #[inline(always)]
296    pub fn is_mask(&self) -> bool {
297        self.is_mask
298    }
299    #[inline(always)]
300    pub fn get_payload_length(&self) -> u64 {
301        self.payload_length
302    }
303    pub fn get_full_frame_length(&self) -> u64 {
304        self.full_frame_length
305    }
306    #[inline(always)]
307    pub fn get_payload(self) -> Vec<u8> {
308        self.payload
309    }
310}
311
// Frame-parsing tests. Fixtures are raw client frames:
//   byte 0 = FIN/RSV bits + opcode, byte 1 = MASK bit + 7-bit length,
//   then any extended length, the 4-byte masking key and the masked payload.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Message;
    // Header only: the parser expects a masking key after byte 1, so the
    // payload start (6) lies past the 2-byte buffer, `new` returns an error
    // and the `unwrap` panics.
    #[test]
    #[should_panic]
    fn test_buffer_with_no_payload_or_masking_key_but_payload_length() {
        let buffer: Vec<u8> = vec![
            129, // FIN(128) + Opcode(1)
            129, // MASK(128) + PayloadLength(1)
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dataframe.get_message().unwrap();
    }
    // Declares a 1-byte payload but the buffer ends right after the masking
    // key; the parser tolerates this and yields an empty payload.
    #[test]
    fn test_buffer_with_no_payload_but_masking_key_and_payload_length() {
        let buffer: Vec<u8> = vec![
            129, // FIN(128) + Opcode(1)
            129, // MASK(128) + PayloadLength(1)
            0, 0, 0, 0,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dataframe.get_message().unwrap();
    }
    // Unmasked frame: the parser still reserves 4 bytes for a masking key, so
    // the payload start (6) exceeds the 2-byte buffer -> InvalidPayload.
    #[test]
    fn test_buffer_with_no_payload_or_mask() {
        let buffer: Vec<u8> = vec![
            129, // FIN(128) + Opcode(1)
            0,
        ];
        let result = DataframeBuilder::new(buffer);
        assert_eq!(result.err().unwrap(), WsGonzaleError::InvalidPayload);
    }
    // Close frames are reported by the builder as a ConnectionClosed error.
    #[test]
    fn test_close_frame_from_client() {
        let buffer: Vec<u8> = vec![
            136, // FIN(128) + Opcode(8)
            128, // MASK(128)
        ];
        let result = DataframeBuilder::new(buffer);
        assert_eq!(result.err().unwrap(), WsGonzaleError::ConnectionClosed);
    }
    // Masked, zero-length text frame.
    #[test]
    fn test_buffer_with_no_payload_with_masking_key() {
        let buffer: Vec<u8> = vec![
            129, // FIN(128) + Opcode(1)
            128, // MASK(128)
            0, 0, 0, 0,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dataframe.get_message().unwrap();
    }
    // 139 = MASK(128) + length 11; bytes 2..6 are the masking key,
    // followed by the 11 masked bytes of "Hello World".
    #[test]
    fn test_buffer_hello_world() {
        let str = "Hello World";
        let buffer: Vec<u8> = vec![
            129, 139, 90, 212, 118, 181, 18, 177, 26, 217, 53, 244, 33, 218, 40, 184, 18,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dbg!(&dataframe);
        assert!(dataframe.is_fin());
        assert!(dataframe.is_mask());
        assert_eq!(
            String::from_utf8(dataframe.get_payload().to_vec())
                .unwrap()
                .as_str(),
            str
        );
    }
    // 255 = MASK(128) + 127, so a 64-bit length follows:
    // 0x0000_0000_0007_73B8 = 488_376. The buffer stops after the masking
    // key; only the decoded length is checked.
    #[test]
    fn test_payload_size() {
        let s = (0..488376).map(|_| "a").collect::<String>();
        let buffer = vec![129, 255, 0, 0, 0, 0, 0, 7, 115, 184, 105, 143, 80, 179];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        assert_eq!(dataframe.get_payload_length(), s.len() as u64);
    }

    // Smoke test: the "Hello World" frame parses.
    #[test]
    fn test_buffer_to_dataframe() {
        let buffer: Vec<u8> = vec![
            129, 139, 90, 212, 118, 181, 18, 177, 26, 217, 53, 244, 33, 218, 40, 184, 18,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dbg!(dataframe);
    }
    // 254 = MASK(128) + 126, so a 16-bit length follows: 0x007E = 126.
    #[test]
    fn test_buffer_126_length() {
        let str = "xZHtBeHbpCWCTCozNw0GxAdQ8Qqqtex5Zje8FBaVQpxrigx92BpLYYiXZnAA70CdNslWvgdSMz0vfUggF8U8wrULZz7ns1tUi5BDWmxx0XS5LsBeyFuaCq4NDAvwbi";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 126, 202, 250, 57, 41, 178, 160, 113, 93, 136, 159, 113, 75, 186, 185,
            110, 106, 158, 185, 86, 83, 132, 141, 9, 110, 178, 187, 93, 120, 242, 171, 72, 88, 190,
            159, 65, 28, 144, 144, 92, 17, 140, 184, 88, 127, 155, 138, 65, 91, 163, 157, 65, 16,
            248, 184, 73, 101, 147, 163, 80, 113, 144, 148, 120, 104, 253, 202, 122, 77, 132, 137,
            85, 126, 188, 157, 93, 122, 135, 128, 9, 95, 172, 175, 94, 78, 140, 194, 108, 17, 189,
            136, 108, 101, 144, 128, 14, 71, 185, 203, 77, 124, 163, 207, 123, 109, 157, 151, 65,
            81, 250, 162, 106, 28, 134, 137, 123, 76, 179, 188, 76, 72, 137, 139, 13, 103, 142,
            187, 79, 94, 168, 147,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // Same frame as above followed by 4 extra bytes, which must be ignored.
    #[test]
    fn test_buffer_126_overflow_length() {
        let str = "xZHtBeHbpCWCTCozNw0GxAdQ8Qqqtex5Zje8FBaVQpxrigx92BpLYYiXZnAA70CdNslWvgdSMz0vfUggF8U8wrULZz7ns1tUi5BDWmxx0XS5LsBeyFuaCq4NDAvwbi";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 126, 202, 250, 57, 41, 178, 160, 113, 93, 136, 159, 113, 75, 186, 185,
            110, 106, 158, 185, 86, 83, 132, 141, 9, 110, 178, 187, 93, 120, 242, 171, 72, 88, 190,
            159, 65, 28, 144, 144, 92, 17, 140, 184, 88, 127, 155, 138, 65, 91, 163, 157, 65, 16,
            248, 184, 73, 101, 147, 163, 80, 113, 144, 148, 120, 104, 253, 202, 122, 77, 132, 137,
            85, 126, 188, 157, 93, 122, 135, 128, 9, 95, 172, 175, 94, 78, 140, 194, 108, 17, 189,
            136, 108, 101, 144, 128, 14, 71, 185, 203, 77, 124, 163, 207, 123, 109, 157, 151, 65,
            81, 250, 162, 106, 28, 134, 137, 123, 76, 179, 188, 76, 72, 137, 139, 13, 103, 142,
            187, 79, 94, 168, 147, 0, 0, 0, 0,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // Despite the name, this also uses the 16-bit length form (marker 126):
    // the declared length is 0x007F = 127 bytes.
    #[test]
    fn test_buffer_127_length() {
        let str = "xZHtBeHbpCWCTCozNw0GxAdQ8Qqqtex5Zje8FBaVQpxrigx92BpLYYiXZnAA70CdNslWvgdSMz0vfUggF8U8wrULZz7ns1tUi5BDWmxx0XS5LsBeyFuaCq4NDAvwbia";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 127, 238, 233, 37, 50, 150, 179, 109, 70, 172, 140, 109, 80, 158, 170,
            114, 113, 186, 170, 74, 72, 160, 158, 21, 117, 150, 168, 65, 99, 214, 184, 84, 67, 154,
            140, 93, 7, 180, 131, 64, 10, 168, 171, 68, 100, 191, 153, 93, 64, 135, 142, 93, 11,
            220, 171, 85, 126, 183, 176, 76, 106, 180, 135, 100, 115, 217, 217, 102, 86, 160, 154,
            73, 101, 152, 142, 65, 97, 163, 147, 21, 68, 136, 188, 66, 85, 168, 209, 112, 10, 153,
            155, 112, 126, 180, 147, 18, 92, 157, 216, 81, 103, 135, 220, 103, 118, 185, 132, 93,
            74, 222, 177, 118, 7, 162, 154, 103, 87, 151, 175, 80, 83, 173, 152, 17, 124, 170, 168,
            83, 69, 140, 128, 68,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // Same 127-byte frame followed by 4 extra bytes, which must be ignored.
    #[test]
    fn test_buffer_127_overflow_length() {
        let str = "xZHtBeHbpCWCTCozNw0GxAdQ8Qqqtex5Zje8FBaVQpxrigx92BpLYYiXZnAA70CdNslWvgdSMz0vfUggF8U8wrULZz7ns1tUi5BDWmxx0XS5LsBeyFuaCq4NDAvwbia";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 127, 238, 233, 37, 50, 150, 179, 109, 70, 172, 140, 109, 80, 158, 170,
            114, 113, 186, 170, 74, 72, 160, 158, 21, 117, 150, 168, 65, 99, 214, 184, 84, 67, 154,
            140, 93, 7, 180, 131, 64, 10, 168, 171, 68, 100, 191, 153, 93, 64, 135, 142, 93, 11,
            220, 171, 85, 126, 183, 176, 76, 106, 180, 135, 100, 115, 217, 217, 102, 86, 160, 154,
            73, 101, 152, 142, 65, 97, 163, 147, 21, 68, 136, 188, 66, 85, 168, 209, 112, 10, 153,
            155, 112, 126, 180, 147, 18, 92, 157, 216, 81, 103, 135, 220, 103, 118, 185, 132, 93,
            74, 222, 177, 118, 7, 162, 154, 103, 87, 151, 175, 80, 83, 173, 152, 17, 124, 170, 168,
            83, 69, 140, 128, 68, 0, 0, 0, 0,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // 16-bit length form with a declared length of 0x0098 = 152 bytes.
    #[test]
    fn test_buffer_large() {
        let str = "asdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadad";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 152, 156, 22, 133, 192, 253, 101, 225, 179, 253, 114, 228, 179, 248, 119,
            246, 164, 253, 114, 246, 161, 248, 119, 225, 161, 239, 114, 246, 161, 248, 119, 246,
            164, 253, 101, 225, 161, 248, 101, 228, 164, 253, 114, 228, 179, 248, 101, 228, 164,
            253, 101, 225, 161, 239, 114, 228, 164, 239, 119, 225, 161, 248, 119, 246, 164, 239,
            119, 225, 161, 239, 114, 228, 179, 248, 119, 225, 179, 253, 114, 228, 164, 253, 101,
            225, 179, 253, 114, 228, 179, 248, 119, 246, 164, 253, 114, 246, 161, 248, 119, 225,
            161, 239, 114, 246, 161, 248, 119, 246, 164, 253, 101, 225, 161, 248, 101, 228, 164,
            253, 114, 228, 179, 248, 101, 228, 164, 253, 101, 225, 161, 239, 114, 228, 164, 239,
            119, 225, 161, 248, 119, 246, 164, 239, 119, 225, 161, 239, 114, 228, 179, 248, 119,
            225, 179, 253, 114, 228, 164,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // Same 152-byte frame followed by 4 extra bytes, which must be ignored.
    #[test]
    fn test_buffer_overflow_large() {
        let str = "asdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadad";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 152, 156, 22, 133, 192, 253, 101, 225, 179, 253, 114, 228, 179, 248, 119,
            246, 164, 253, 114, 246, 161, 248, 119, 225, 161, 239, 114, 246, 161, 248, 119, 246,
            164, 253, 101, 225, 161, 248, 101, 228, 164, 253, 114, 228, 179, 248, 101, 228, 164,
            253, 101, 225, 161, 239, 114, 228, 164, 239, 119, 225, 161, 248, 119, 246, 164, 239,
            119, 225, 161, 239, 114, 228, 179, 248, 119, 225, 179, 253, 114, 228, 164, 253, 101,
            225, 179, 253, 114, 228, 179, 248, 119, 246, 164, 253, 114, 246, 161, 248, 119, 225,
            161, 239, 114, 246, 161, 248, 119, 246, 164, 253, 101, 225, 161, 248, 101, 228, 164,
            253, 114, 228, 179, 248, 101, 228, 164, 253, 101, 225, 161, 239, 114, 228, 164, 239,
            119, 225, 161, 248, 119, 246, 164, 239, 119, 225, 161, 239, 114, 228, 179, 248, 119,
            225, 179, 253, 114, 228, 164, 0, 0, 0, 0,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
}