1use crate::{message::Message, WsGonzaleError, WsGonzaleResult};
2
3#[inline(always)]
5pub fn get_buffer(message: Message) -> Vec<u8> {
6 let mut buffer: Vec<u8> = Vec::new();
7 buffer.push(129);
8 let s = match message {
9 Message::Text(s) => s,
10 _ => "".to_owned(),
11 };
12 match s.len() as u64 {
13 size @ 0..=125 => {
14 buffer.push(size as u8);
15 }
16 size if size > u32::MAX as u64 => {
17 let bytes: [u8; 8] = (size as u64).to_be_bytes();
18 buffer.push(127);
19 buffer.extend_from_slice(&bytes);
20 }
21 size if size <= u32::MAX as u64 => {
22 let bytes: [u8; 2] = (size as u16).to_be_bytes();
23 buffer.push(126);
24 buffer.extend_from_slice(&bytes);
25 }
26 _ => panic!("Don't know what to do here..."),
27 }
28 buffer.extend_from_slice(s.as_bytes());
29 buffer
30}
/// XORs the bytes behind `incoming` in place with the repeating 4-byte `mask`.
///
/// The slice is moved out of `incoming`, which is left holding an empty slice.
/// The returned slice is a shared view of the same (now transformed) bytes.
/// Applying the same mask twice restores the original data.
#[inline(always)]
pub fn mask_payload<'a, 'b>(incoming: &'a mut &'b mut [u8], mask: [u8; 4]) -> &'a [u8] {
    let payload: &'b mut [u8] = std::mem::take(incoming);
    payload
        .iter_mut()
        .zip(mask.iter().cycle())
        .for_each(|(byte, key)| *byte ^= *key);
    payload
}
/// Raw bytes of one incoming frame; wrapped by [`DataframeBuilder::new`] and parsed into a
/// [`Dataframe`].
pub struct DataframeBuilder(Vec<u8>);
/// A parsed frame, produced by [`DataframeBuilder::new`].
#[derive(Debug)]
pub struct Dataframe {
    // FIN bit of header byte 0.
    fin: bool,
    // Reserved bits RSV1..RSV3 of header byte 0.
    rsv1: bool,
    rsv2: bool,
    rsv3: bool,
    // Mask bit of header byte 1.
    is_mask: bool,
    // Raw opcode: the low nibble of header byte 0.
    opcode: u8,
    // Payload length declared by the header (7-bit, 16-bit or 64-bit form).
    payload_length: u64,
    // Header size plus `payload_length`, as computed by the builder.
    full_frame_length: u64,
    // Masking key; all zeros when the frame is unmasked or the key bytes were missing.
    masking_key: [u8; 4],
    // Payload bytes following the header, with the mask removed when the frame is masked.
    // At most `payload_length` bytes; fewer if the input buffer was truncated.
    payload: Vec<u8>,
}
/// Frame opcodes this module distinguishes (the low nibble of header byte 0).
///
/// Any value without a dedicated variant, including 2, converts to
/// `Opcode::Unknown` via `From<u8>`.
#[derive(PartialEq)]
enum Opcode {
    Continuation = 0,
    Text = 1,
    Close = 8,
    Ping = 9,
    Pong = 10,
    // Catch-all; not a wire value (its implicit discriminant is 11).
    Unknown,
}
65impl From<u8> for Opcode {
66 fn from(v: u8) -> Opcode {
67 match v {
68 0 => Opcode::Continuation,
69 1 => Opcode::Text,
70 8 => Opcode::Close,
71 9 => Opcode::Ping,
72 10 => Opcode::Pong,
73 _ => Opcode::Unknown,
74 }
75 }
76}
77
/// How the payload length is encoded, selected by the 7-bit length field of header byte 1.
#[derive(Debug)]
enum ExtraSize {
    /// Field value 0..=125: that value is the payload length (carried here).
    Zero(u8),
    /// Field value 126: a 16-bit big-endian length follows in bytes 2..4.
    Two,
    /// Field value 127: a 64-bit big-endian length follows in bytes 2..10.
    Eight,
}
/// Bit masks for the two fixed header bytes of a frame.
mod frame_positions {
    // Header byte 0: FIN flag, three reserved flags, then the 4-bit opcode.
    pub const FIN: u8 = 1 << 7;
    pub const RSV1: u8 = 1 << 6;
    pub const RSV2: u8 = 1 << 5;
    pub const RSV3: u8 = 1 << 4;
    pub const MASK_OPCODE: u8 = 0x0F;

    // Header byte 1: mask flag, then the 7-bit payload-length field.
    pub const IS_MASK: u8 = 1 << 7;
    pub const MASK_PAYLOAD_LENGTH: u8 = 0x7F;
}
95impl DataframeBuilder {
96 pub fn new(buffer: Vec<u8>) -> WsGonzaleResult<Dataframe> {
97 DataframeBuilder(buffer).get_dataframe()
98 }
99 #[inline(always)]
100 fn is_fin(&self) -> bool {
101 self.0
102 .get(0)
103 .map(|frame| (frame & frame_positions::FIN) == frame_positions::FIN)
104 .unwrap_or(false)
105 }
106 #[inline(always)]
107 fn is_rsv1(&self) -> bool {
108 self.0
109 .get(0)
110 .map(|frame| (frame & frame_positions::RSV1) == frame_positions::RSV1)
111 .unwrap_or(false)
112 }
113 #[inline(always)]
114 fn is_rsv2(&self) -> bool {
115 self.0
116 .get(0)
117 .map(|frame| (frame & frame_positions::RSV2) == frame_positions::RSV2)
118 .unwrap_or(false)
119 }
120 #[inline(always)]
121 fn is_rsv3(&self) -> bool {
122 self.0
123 .get(0)
124 .map(|frame| (frame & frame_positions::RSV3) == frame_positions::RSV3)
125 .unwrap_or(false)
126 }
127 #[inline(always)]
128 fn get_opcode(&self) -> u8 {
130 self.0
132 .get(0)
133 .map(|frame| frame & frame_positions::MASK_OPCODE)
134 .unwrap_or(8)
135 }
136 #[inline(always)]
137 fn is_mask(&self) -> bool {
138 self.0
139 .get(1)
140 .map(|frame| (frame & frame_positions::IS_MASK) == frame_positions::IS_MASK)
141 .unwrap_or(false)
142 }
143 #[inline(always)]
145 fn get_short_payload_length(&self) -> u8 {
146 self.0
147 .get(1)
148 .map(|frame| frame & frame_positions::MASK_PAYLOAD_LENGTH)
149 .unwrap_or(0)
150 }
151 #[inline(always)]
152 fn get_extra_payload_bytes(&self) -> WsGonzaleResult<ExtraSize> {
153 let result = match self.get_short_payload_length() {
154 size @ 0..=125 => ExtraSize::Zero(size),
155 126 => ExtraSize::Two,
156 127 => ExtraSize::Eight,
157 _ => unreachable!("Max payload for a dataframe in WS spec is 127"),
158 };
159 Ok(result)
160 }
161 #[inline(always)]
162 fn get_payload_length(&self) -> WsGonzaleResult<u64> {
163 let slice = self.0.as_slice();
164 let result = match self.get_extra_payload_bytes()? {
165 ExtraSize::Zero(size) => size as u64,
166 ExtraSize::Two => match slice {
167 [_, _, first, second, ..] if slice.len() > 4 => {
168 u32::from_be_bytes([0, 0, *first, *second]) as u64
169 }
170 _ => return Err(WsGonzaleError::Unknown),
171 },
172 ExtraSize::Eight => match slice {
173 [_, _, first, second, third, fourth, fifth, sixth, seventh, eighth, ..]
174 if slice.len() > 8 =>
175 {
176 u64::from_be_bytes([
177 *first, *second, *third, *fourth, *fifth, *sixth, *seventh, *eighth,
178 ]) as u64
179 }
180 _ => return Err(WsGonzaleError::Unknown),
181 },
182 };
183
184 Ok(result)
185 }
186
187 fn get_payload_start_pos(&self) -> WsGonzaleResult<u64> {
188 let result = match self.get_extra_payload_bytes()? {
189 ExtraSize::Zero(_) => 6,
190 ExtraSize::Two => 8,
191 ExtraSize::Eight => 14,
192 };
193 Ok(result)
194 }
195 pub fn get_full_frame_length(&self) -> WsGonzaleResult<u64> {
196 let size = self.get_payload_start_pos()? + self.get_payload_length()?;
197
198 Ok(size)
199 }
200 #[inline(always)]
201 fn get_masking_key_start(&self) -> WsGonzaleResult<u8> {
202 let result = match self.get_extra_payload_bytes()? {
203 ExtraSize::Zero(_) => 0,
204 ExtraSize::Two => 2,
205 ExtraSize::Eight => 8,
206 };
207 Ok(result)
208 }
209 #[inline(always)]
210 fn get_masking_key(&self) -> WsGonzaleResult<[u8; 4]> {
211 let start = 2 + self.get_masking_key_start()? as usize;
212 let end = start + 4;
213 if self.is_mask() && self.0.len() >= end {
214 let mut buffer: [u8; 4] = [0; 4];
215 buffer.copy_from_slice(&self.0[start..end]);
216 Ok(buffer)
217 } else {
218 Ok([0, 0, 0, 0])
220 }
221 }
222 #[inline(always)]
223 fn get_payload(mut self) -> WsGonzaleResult<Vec<u8>> {
224 let start_payload = self.get_payload_start_pos()? as usize;
225 let is_mask = self.is_mask();
226 let masking_key = self.get_masking_key()?;
227 let payload_length = self.get_payload_length()? as usize;
228
229 if Opcode::from(self.get_opcode()) == Opcode::Close {
230 return Err(WsGonzaleError::ConnectionClosed);
232 }
233 if start_payload > self.0.len() {
234 return Err(WsGonzaleError::InvalidPayload);
235 }
236 self.0.drain(0..start_payload);
238 let mut data = self.0.into_iter().take(payload_length).collect::<Vec<u8>>();
239 if is_mask {
240 mask_payload(&mut &mut *data, masking_key);
241 }
242 Ok(data)
243 }
244 #[inline(always)]
245 fn get_dataframe(self) -> WsGonzaleResult<Dataframe> {
246 let result = Dataframe {
247 fin: self.is_fin(),
248 rsv1: self.is_rsv1(),
249 rsv2: self.is_rsv2(),
250 rsv3: self.is_rsv3(),
251 is_mask: self.is_mask(),
252 opcode: self.get_opcode(),
253 payload_length: self.get_payload_length()?,
254 full_frame_length: self.get_full_frame_length()?,
255 masking_key: self.get_masking_key()?,
256 payload: self.get_payload()?,
257 };
258 Ok(result)
259 }
260}
261impl Dataframe {
262 #[inline(always)]
263 pub fn get_message(self) -> WsGonzaleResult<Message> {
264 let result = match self.opcode {
265 1 => Message::Text(
266 String::from_utf8_lossy(&self.get_payload())
267 .parse()
268 .map_err(|_| WsGonzaleError::InvalidPayload)?,
269 ),
270 8 => Message::Close,
271 _ => Message::Unknown,
272 };
273 Ok(result)
274 }
275 #[inline(always)]
276 pub fn is_fin(&self) -> bool {
277 self.fin
278 }
279 #[inline(always)]
280 pub fn is_rsv1(&self) -> bool {
281 self.rsv1
282 }
283 #[inline(always)]
284 pub fn is_rsv2(&self) -> bool {
285 self.rsv2
286 }
287 #[inline(always)]
288 pub fn is_rsv3(&self) -> bool {
289 self.rsv3
290 }
291 #[inline(always)]
292 pub fn get_opcode(&self) -> u8 {
293 self.opcode
294 }
295 #[inline(always)]
296 pub fn is_mask(&self) -> bool {
297 self.is_mask
298 }
299 #[inline(always)]
300 pub fn get_payload_length(&self) -> u64 {
301 self.payload_length
302 }
303 pub fn get_full_frame_length(&self) -> u64 {
304 self.full_frame_length
305 }
306 #[inline(always)]
307 pub fn get_payload(self) -> Vec<u8> {
308 self.payload
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Message;
    // Header bytes used below: 129 = FIN + opcode 1 (text), 136 = FIN + opcode 8 (close).
    // In byte 1, values >= 128 have the mask bit set; the low 7 bits are the length field,
    // where 126 means a 16-bit length follows and 127 means a 64-bit length follows.

    // Masked text frame declaring a 1-byte payload, but the masking key and payload are
    // missing: `DataframeBuilder::new` errors, so `unwrap` panics.
    #[test]
    #[should_panic]
    fn test_buffer_with_no_payload_or_masking_key_but_payload_length() {
        let buffer: Vec<u8> = vec![
            129, 129, // FIN + text, mask bit + length 1
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dataframe.get_message().unwrap();
    }
    // Masked frame declaring 1 payload byte with an all-zero key but no payload bytes:
    // truncated frames are accepted and yield only the payload bytes that are present.
    #[test]
    fn test_buffer_with_no_payload_but_masking_key_and_payload_length() {
        let buffer: Vec<u8> = vec![
            129, 129, 0, 0, 0, 0, // header, then the 4-byte masking key
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dataframe.get_message().unwrap();
    }
    // Unmasked, empty text frame: expected to be rejected with `InvalidPayload`.
    #[test]
    fn test_buffer_with_no_payload_or_mask() {
        let buffer: Vec<u8> = vec![
            129, 0, // FIN + text, no mask bit, length 0
        ];
        let result = DataframeBuilder::new(buffer);
        assert_eq!(result.err().unwrap(), WsGonzaleError::InvalidPayload);
    }
    // Close frame (opcode 8) with the mask bit set, length 0, and no key bytes.
    #[test]
    fn test_close_frame_from_client() {
        let buffer: Vec<u8> = vec![
            136, 128, // FIN + close, mask bit + length 0
        ];
        let result = DataframeBuilder::new(buffer);
        assert_eq!(result.err().unwrap(), WsGonzaleError::ConnectionClosed);
    }
    // Masked, empty text frame with an all-zero masking key.
    #[test]
    fn test_buffer_with_no_payload_with_masking_key() {
        let buffer: Vec<u8> = vec![
            129, 128, 0, 0, 0, 0, // header, then the 4-byte masking key
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dataframe.get_message().unwrap();
    }
    // Masked "Hello World": 139 = mask bit + length 11, then key 90, 212, 118, 181 and
    // the 11 masked payload bytes.
    #[test]
    fn test_buffer_hello_world() {
        let str = "Hello World";
        let buffer: Vec<u8> = vec![
            129, 139, 90, 212, 118, 181, 18, 177, 26, 217, 53, 244, 33, 218, 40, 184, 18,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dbg!(&dataframe);
        assert!(dataframe.is_fin());
        assert!(dataframe.is_mask());
        assert_eq!(
            String::from_utf8(dataframe.get_payload().to_vec())
                .unwrap()
                .as_str(),
            str
        );
    }
    // 255 = mask bit + 127, so a 64-bit length follows: 0,0,0,0,0,7,115,184 = 488376.
    // The last four bytes are the masking key. The payload itself is absent; only the
    // declared length is checked.
    #[test]
    fn test_payload_size() {
        let s = (0..488376).map(|_| "a").collect::<String>();
        let buffer = vec![129, 255, 0, 0, 0, 0, 0, 7, 115, 184, 105, 143, 80, 179];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        assert_eq!(dataframe.get_payload_length(), s.len() as u64);
    }

    // Same bytes as `test_buffer_hello_world`; only checks that parsing succeeds.
    #[test]
    fn test_buffer_to_dataframe() {
        let buffer: Vec<u8> = vec![
            129, 139, 90, 212, 118, 181, 18, 177, 26, 217, 53, 244, 33, 218, 40, 184, 18,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        dbg!(dataframe);
    }
    // 254 = mask bit + 126, so a 16-bit length follows: 0,126 = 126 bytes. Then the
    // masking key 202, 250, 57, 41 and the 126 masked payload bytes.
    #[test]
    fn test_buffer_126_length() {
        let str = "xZHtBeHbpCWCTCozNw0GxAdQ8Qqqtex5Zje8FBaVQpxrigx92BpLYYiXZnAA70CdNslWvgdSMz0vfUggF8U8wrULZz7ns1tUi5BDWmxx0XS5LsBeyFuaCq4NDAvwbi";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 126, 202, 250, 57, 41, 178, 160, 113, 93, 136, 159, 113, 75, 186, 185,
            110, 106, 158, 185, 86, 83, 132, 141, 9, 110, 178, 187, 93, 120, 242, 171, 72, 88, 190,
            159, 65, 28, 144, 144, 92, 17, 140, 184, 88, 127, 155, 138, 65, 91, 163, 157, 65, 16,
            248, 184, 73, 101, 147, 163, 80, 113, 144, 148, 120, 104, 253, 202, 122, 77, 132, 137,
            85, 126, 188, 157, 93, 122, 135, 128, 9, 95, 172, 175, 94, 78, 140, 194, 108, 17, 189,
            136, 108, 101, 144, 128, 14, 71, 185, 203, 77, 124, 163, 207, 123, 109, 157, 151, 65,
            81, 250, 162, 106, 28, 134, 137, 123, 76, 179, 188, 76, 72, 137, 139, 13, 103, 142,
            187, 79, 94, 168, 147,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // Same frame followed by four extra bytes, which must not end up in the payload.
    #[test]
    fn test_buffer_126_overflow_length() {
        let str = "xZHtBeHbpCWCTCozNw0GxAdQ8Qqqtex5Zje8FBaVQpxrigx92BpLYYiXZnAA70CdNslWvgdSMz0vfUggF8U8wrULZz7ns1tUi5BDWmxx0XS5LsBeyFuaCq4NDAvwbi";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 126, 202, 250, 57, 41, 178, 160, 113, 93, 136, 159, 113, 75, 186, 185,
            110, 106, 158, 185, 86, 83, 132, 141, 9, 110, 178, 187, 93, 120, 242, 171, 72, 88, 190,
            159, 65, 28, 144, 144, 92, 17, 140, 184, 88, 127, 155, 138, 65, 91, 163, 157, 65, 16,
            248, 184, 73, 101, 147, 163, 80, 113, 144, 148, 120, 104, 253, 202, 122, 77, 132, 137,
            85, 126, 188, 157, 93, 122, 135, 128, 9, 95, 172, 175, 94, 78, 140, 194, 108, 17, 189,
            136, 108, 101, 144, 128, 14, 71, 185, 203, 77, 124, 163, 207, 123, 109, 157, 151, 65,
            81, 250, 162, 106, 28, 134, 137, 123, 76, 179, 188, 76, 72, 137, 139, 13, 103, 142,
            187, 79, 94, 168, 147, 0, 0, 0, 0,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // A 127-byte payload, still in the 16-bit form: 254 = mask bit + 126; 0,127 = 127.
    #[test]
    fn test_buffer_127_length() {
        let str = "xZHtBeHbpCWCTCozNw0GxAdQ8Qqqtex5Zje8FBaVQpxrigx92BpLYYiXZnAA70CdNslWvgdSMz0vfUggF8U8wrULZz7ns1tUi5BDWmxx0XS5LsBeyFuaCq4NDAvwbia";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 127, 238, 233, 37, 50, 150, 179, 109, 70, 172, 140, 109, 80, 158, 170,
            114, 113, 186, 170, 74, 72, 160, 158, 21, 117, 150, 168, 65, 99, 214, 184, 84, 67, 154,
            140, 93, 7, 180, 131, 64, 10, 168, 171, 68, 100, 191, 153, 93, 64, 135, 142, 93, 11,
            220, 171, 85, 126, 183, 176, 76, 106, 180, 135, 100, 115, 217, 217, 102, 86, 160, 154,
            73, 101, 152, 142, 65, 97, 163, 147, 21, 68, 136, 188, 66, 85, 168, 209, 112, 10, 153,
            155, 112, 126, 180, 147, 18, 92, 157, 216, 81, 103, 135, 220, 103, 118, 185, 132, 93,
            74, 222, 177, 118, 7, 162, 154, 103, 87, 151, 175, 80, 83, 173, 152, 17, 124, 170, 168,
            83, 69, 140, 128, 68,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // Same frame followed by four extra bytes, which must not end up in the payload.
    #[test]
    fn test_buffer_127_overflow_length() {
        let str = "xZHtBeHbpCWCTCozNw0GxAdQ8Qqqtex5Zje8FBaVQpxrigx92BpLYYiXZnAA70CdNslWvgdSMz0vfUggF8U8wrULZz7ns1tUi5BDWmxx0XS5LsBeyFuaCq4NDAvwbia";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 127, 238, 233, 37, 50, 150, 179, 109, 70, 172, 140, 109, 80, 158, 170,
            114, 113, 186, 170, 74, 72, 160, 158, 21, 117, 150, 168, 65, 99, 214, 184, 84, 67, 154,
            140, 93, 7, 180, 131, 64, 10, 168, 171, 68, 100, 191, 153, 93, 64, 135, 142, 93, 11,
            220, 171, 85, 126, 183, 176, 76, 106, 180, 135, 100, 115, 217, 217, 102, 86, 160, 154,
            73, 101, 152, 142, 65, 97, 163, 147, 21, 68, 136, 188, 66, 85, 168, 209, 112, 10, 153,
            155, 112, 126, 180, 147, 18, 92, 157, 216, 81, 103, 135, 220, 103, 118, 185, 132, 93,
            74, 222, 177, 118, 7, 162, 154, 103, 87, 151, 175, 80, 83, 173, 152, 17, 124, 170, 168,
            83, 69, 140, 128, 68, 0, 0, 0, 0,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // 152-byte payload in the 16-bit form (0,152), masking key 156, 22, 133, 192.
    #[test]
    fn test_buffer_large() {
        let str = "asdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadad";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 152, 156, 22, 133, 192, 253, 101, 225, 179, 253, 114, 228, 179, 248, 119,
            246, 164, 253, 114, 246, 161, 248, 119, 225, 161, 239, 114, 246, 161, 248, 119, 246,
            164, 253, 101, 225, 161, 248, 101, 228, 164, 253, 114, 228, 179, 248, 101, 228, 164,
            253, 101, 225, 161, 239, 114, 228, 164, 239, 119, 225, 161, 248, 119, 246, 164, 239,
            119, 225, 161, 239, 114, 228, 179, 248, 119, 225, 179, 253, 114, 228, 164, 253, 101,
            225, 179, 253, 114, 228, 179, 248, 119, 246, 164, 253, 114, 246, 161, 248, 119, 225,
            161, 239, 114, 246, 161, 248, 119, 246, 164, 253, 101, 225, 161, 248, 101, 228, 164,
            253, 114, 228, 179, 248, 101, 228, 164, 253, 101, 225, 161, 239, 114, 228, 164, 239,
            119, 225, 161, 248, 119, 246, 164, 239, 119, 225, 161, 239, 114, 228, 179, 248, 119,
            225, 179, 253, 114, 228, 164,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
    // Same frame followed by four extra bytes, which must not end up in the payload.
    #[test]
    fn test_buffer_overflow_large() {
        let str = "asdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadadasdsadasdasdadsadad";
        let buffer: Vec<u8> = vec![
            129, 254, 0, 152, 156, 22, 133, 192, 253, 101, 225, 179, 253, 114, 228, 179, 248, 119,
            246, 164, 253, 114, 246, 161, 248, 119, 225, 161, 239, 114, 246, 161, 248, 119, 246,
            164, 253, 101, 225, 161, 248, 101, 228, 164, 253, 114, 228, 179, 248, 101, 228, 164,
            253, 101, 225, 161, 239, 114, 228, 164, 239, 119, 225, 161, 248, 119, 246, 164, 239,
            119, 225, 161, 239, 114, 228, 179, 248, 119, 225, 179, 253, 114, 228, 164, 253, 101,
            225, 179, 253, 114, 228, 179, 248, 119, 246, 164, 253, 114, 246, 161, 248, 119, 225,
            161, 239, 114, 246, 161, 248, 119, 246, 164, 253, 101, 225, 161, 248, 101, 228, 164,
            253, 114, 228, 179, 248, 101, 228, 164, 253, 101, 225, 161, 239, 114, 228, 164, 239,
            119, 225, 161, 248, 119, 246, 164, 239, 119, 225, 161, 239, 114, 228, 179, 248, 119,
            225, 179, 253, 114, 228, 164, 0, 0, 0, 0,
        ];
        let dataframe: Dataframe = DataframeBuilder::new(buffer).unwrap();
        let message = dataframe.get_message().unwrap();
        assert_eq!(message, Message::Text(str.to_string()));
    }
}