1use bytes::{BufMut, Buf, BytesMut, BigEndian};
5use frame::base::{Frame, OpCode};
6use slog::Logger;
7use std::io::{self, Cursor};
8use tokio_io::codec::{Decoder, Encoder};
9use util;
10use vatfluid::{Success, validate};
11
/// Value of the 7-bit length field announcing that a 16-bit extended payload length follows.
const TWO_EXT: u8 = 0x7E;
/// Value of the 7-bit length field announcing that a 64-bit extended payload length follows.
const EIGHT_EXT: u8 = 0x7F;
18
/// Progress marker for the incremental frame decoder.
///
/// Each variant names the last part of the current frame that has been fully read.
#[derive(Clone, Debug)]
pub enum DecodeState {
    /// Nothing of the current frame has been read yet.
    NONE,
    /// The two fixed header bytes have been read.
    HEADER,
    /// The payload length (including any extended length) is known.
    LENGTH,
    /// The mask key has been read, or the frame carries no mask.
    MASK,
    /// The whole payload has been read; the frame is complete.
    FULL,
}
33
34impl Default for DecodeState {
35 fn default() -> DecodeState {
36 DecodeState::NONE
37 }
38}
39
40#[derive(Clone, Debug, Default)]
44pub struct FrameCodec {
45 client: bool,
47 fin: bool,
49 rsv1: bool,
51 rsv2: bool,
53 rsv3: bool,
55 opcode: OpCode,
57 masked: bool,
59 length_code: u8,
61 payload_length: u64,
63 mask_key: u32,
65 extension_data: Option<Vec<u8>>,
67 application_data: Vec<u8>,
69 pos: usize,
71 state: DecodeState,
73 min_len: u64,
75 reserved_bits: u8,
77 stdout: Option<Logger>,
79 stderr: Option<Logger>,
81}
82
83impl FrameCodec {
84 pub fn set_client(&mut self, client: bool) -> &mut FrameCodec {
86 self.client = client;
87 self
88 }
89
90 pub fn set_reserved_bits(&mut self, reserved_bits: u8) -> &mut FrameCodec {
92 self.reserved_bits = reserved_bits;
93 self
94 }
95
96 pub fn stdout(&mut self, logger: Logger) -> &mut FrameCodec {
98 let stdout = logger.new(o!("codec" => "base"));
99 self.stdout = Some(stdout);
100 self
101 }
102
103 pub fn stderr(&mut self, logger: Logger) -> &mut FrameCodec {
105 let stderr = logger.new(o!("codec" => "base"));
106 self.stderr = Some(stderr);
107 self
108 }
109}
110
/// XORs every byte of `buf` in place with the 4-byte masking key `mask`.
///
/// The key is applied in big-endian (network) order: its most significant byte masks
/// `buf[0]`, and the four key bytes repeat across the buffer. Applying the same key twice
/// therefore restores the original bytes.
///
/// Always returns `Ok(())`; the `Result` is kept so callers can continue to use `?`.
fn apply_mask(buf: &mut [u8], mask: u32) -> Result<(), io::Error> {
    // Derive the key bytes directly instead of allocating a buffer on every call.
    #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
    let key = [
        (mask >> 24) as u8,
        (mask >> 16) as u8,
        (mask >> 8) as u8,
        mask as u8,
    ];
    for (byte, &k) in buf.iter_mut().zip(key.iter().cycle()) {
        *byte ^= k;
    }
    Ok(())
}
121
122impl Decoder for FrameCodec {
123 type Item = Frame;
124 type Error = io::Error;
125
126 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
127 let buf_len = buf.len();
128 if buf_len == 0 {
129 return Ok(None);
130 }
131
132 self.min_len = 0;
133 loop {
134 match self.state {
135 DecodeState::NONE => {
136 self.min_len += 2;
137 if (buf_len as u64) < self.min_len {
139 return Ok(None);
140 }
141 let header_bytes = buf.split_to(2);
142 let header = &header_bytes;
143 let first = header[0];
144 let second = header[1];
145
146 self.fin = first & 0x80 != 0;
148 self.rsv1 = first & 0x40 != 0;
149 if self.rsv1 && (self.reserved_bits & 0x4 == 0) {
150 return Err(util::other("invalid rsv1 bit set"));
151 }
152
153 self.rsv2 = first & 0x20 != 0;
154 if self.rsv2 && (self.reserved_bits & 0x2 == 0) {
155 return Err(util::other("invalid rsv2 bit set"));
156 }
157
158 self.rsv3 = first & 0x10 != 0;
159 if self.rsv3 && (self.reserved_bits & 0x1 == 0) {
160 return Err(util::other("invalid rsv3 bit set"));
161 }
162
163 self.opcode = OpCode::from((first & 0x0F) as u8);
164 if self.opcode.is_invalid() {
165 return Err(util::other("invalid opcode set"));
166 }
167 if self.opcode.is_control() && !self.fin {
168 return Err(util::other("control frames must not be fragmented"));
169 }
170
171 self.masked = second & 0x80 != 0;
172 if !self.masked && !self.client {
173 return Err(util::other("all client frames must have a mask"));
174 }
175
176 self.length_code = (second & 0x7F) as u8;
177 self.state = DecodeState::HEADER;
178 }
179 DecodeState::HEADER => {
180 if self.length_code == TWO_EXT {
181 self.min_len += 2;
182 if (buf_len as u64) < self.min_len {
183 self.min_len -= 2;
184 return Ok(None);
185 }
186 let len = Cursor::new(buf.split_to(2)).get_u16::<BigEndian>();
187 self.payload_length = len as u64;
188 self.state = DecodeState::LENGTH;
189 } else if self.length_code == EIGHT_EXT {
190 self.min_len += 8;
191 if (buf_len as u64) < self.min_len {
192 self.min_len -= 8;
193 return Ok(None);
194 }
195 let len = Cursor::new(buf.split_to(8)).get_u64::<BigEndian>();
196 self.payload_length = len as u64;
197 self.state = DecodeState::LENGTH;
198 } else {
199 self.payload_length = self.length_code as u64;
200 self.state = DecodeState::LENGTH;
201 }
202 if self.payload_length > 125 && self.opcode.is_control() {
203 return Err(util::other("invalid control frame"));
204 }
205 }
206 DecodeState::LENGTH => {
207 if self.masked {
208 self.min_len += 4;
209 if (buf_len as u64) < self.min_len {
210 self.min_len -= 4;
211 return Ok(None);
212 }
213 let mask = Cursor::new(buf.split_to(4)).get_u32::<BigEndian>();
214 self.mask_key = mask;
215 self.state = DecodeState::MASK;
216 } else {
217 self.mask_key = 0;
218 self.state = DecodeState::MASK;
219 }
220 }
221 DecodeState::MASK => {
222 if self.payload_length > 0 {
223 let mask = self.mask_key;
224 let app_data_len = self.application_data.len();
225 if buf.is_empty() {
226 return Ok(None);
227 } else if ((buf.len() + app_data_len) as u64) < self.payload_length {
228 self.application_data.extend(buf.take());
229 if self.opcode == OpCode::Text {
230 apply_mask(&mut self.application_data, mask)?;
231 try_trace!(self.stdout, "validating from pos: {}", self.pos);
232 match validate(&self.application_data[self.pos..]) {
233 Ok(Success::Complete(pos)) => {
234 try_trace!(self.stdout, "complete: {}", pos);
235 self.pos += pos;
236 }
237 Ok(Success::Incomplete(_, pos)) => {
238 try_trace!(self.stdout, "incomplete: {}", pos);
239 self.pos += pos;
240 }
241 Err(e) => {
242 try_error!(self.stderr, "{}", e);
243 return Err(util::other("invalid utf-8 sequence"));
244 }
245 }
246 apply_mask(&mut self.application_data, mask)?;
247 }
248 return Ok(None);
249 } else {
250 #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
251 let split_len = (self.payload_length as usize) - app_data_len;
252 self.application_data.extend(buf.split_to(split_len));
253 if self.masked {
254 apply_mask(&mut self.application_data, mask)?;
255 }
256 self.state = DecodeState::FULL;
257 }
258 } else {
259 self.state = DecodeState::FULL;
260 }
261 }
262 DecodeState::FULL => break,
263 }
264 }
265
266 Ok(Some(self.clone().into()))
267 }
268}
269
270impl Encoder for FrameCodec {
271 type Item = Frame;
272 type Error = io::Error;
273
274 fn encode(&mut self, msg: Self::Item, buf: &mut BytesMut) -> io::Result<()> {
275 let mut first_byte = 0_u8;
276
277 if msg.fin() {
278 first_byte |= 0x80;
279 }
280
281 if msg.rsv1() {
282 first_byte |= 0x40;
283 }
284
285 if msg.rsv2() {
286 first_byte |= 0x20;
287 }
288
289 if msg.rsv3() {
290 first_byte |= 0x10;
291 }
292
293 let opcode: u8 = msg.opcode().into();
294 first_byte |= opcode;
295 buf.put(first_byte);
296
297 let mut second_byte = 0_u8;
298
299 if msg.masked() {
300 second_byte |= 0x80;
301 }
302
303 let len = msg.payload_length();
304 if len < TWO_EXT as u64 {
305 #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
306 let cast_len = len as u8;
307 second_byte |= cast_len;
308 buf.put(second_byte);
309 } else if len < 65536 {
310 second_byte |= TWO_EXT;
311 let mut len_buf = BytesMut::with_capacity(2);
312 #[cfg_attr(feature = "cargo-clippy", allow(cast_possible_truncation))]
313 let cast_len = len as u16;
314 len_buf.put_u16::<BigEndian>(cast_len);
315 buf.put(second_byte);
316 buf.extend(len_buf);
317 } else {
318 second_byte |= EIGHT_EXT;
319 let mut len_buf = BytesMut::with_capacity(8);
320 len_buf.put_u64::<BigEndian>(len);
321 buf.put(second_byte);
322 buf.extend(len_buf);
323 }
324
325 if msg.masked() {
326 let mut mask_buf = BytesMut::with_capacity(4);
327 mask_buf.put_u32::<BigEndian>(msg.mask());
328 buf.extend(mask_buf);
329 }
330
331 if !msg.application_data().is_empty() {
332 buf.extend(msg.application_data().clone());
333 }
334
335 Ok(())
336 }
337}
338
339impl From<FrameCodec> for Frame {
340 fn from(frame_codec: FrameCodec) -> Frame {
341 let mut frame: Frame = Default::default();
342 frame.set_fin(frame_codec.fin);
343 frame.set_rsv1(frame_codec.rsv1);
344 frame.set_rsv2(frame_codec.rsv2);
345 frame.set_rsv3(frame_codec.rsv3);
346 frame.set_masked(frame_codec.masked);
347 frame.set_opcode(frame_codec.opcode);
348 frame.set_mask(frame_codec.mask_key);
349 frame.set_payload_length(frame_codec.payload_length);
350 frame.set_application_data(frame_codec.application_data);
351 frame.set_extension_data(frame_codec.extension_data);
352 frame
353 }
354}
355
#[cfg(test)]
mod test {
    use super::FrameCodec;
    use bytes::BytesMut;
    use frame::base::{Frame, OpCode};
    use std::io;
    use tokio_io::codec::Decoder;

    /// Ping with FIN set and the MASK bit clear: invalid for a server-side codec.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const NO_MASK: [u8; 2] = [0x89, 0x00];
    /// Masked ping announcing a 16-bit extended length of 0x1000, which is too long for a
    /// control frame.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const CTRL_PAYLOAD_LEN : [u8; 9] = [0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];

    /// Only the first header byte.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_HEADER: [u8; 1] = [0x89];
    /// Masked ping with a 16-bit extended length, only one of the two length bytes.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_LENGTH_1: [u8; 3] = [0x89, 0xFE, 0x01];
    /// Masked ping with a 64-bit extended length, only four of the eight length bytes.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_LENGTH_2: [u8; 6] = [0x89, 0xFF, 0x01, 0x02, 0x03, 0x04];
    /// Masked binary frame of length 0x0102, only two of the four mask-key bytes.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_MASK: [u8; 6] = [0x82, 0xFE, 0x01, 0x02, 0x00, 0x00];
    /// Masked binary frame of length 5 with a full mask key, only two payload bytes.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PARTIAL_PAYLOAD: [u8; 8] = [0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00];

    /// Complete masked ping with an empty payload.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    const PING_NO_DATA: [u8; 6] = [0x89, 0x80, 0x00, 0x00, 0x00, 0x01];

    /// Decodes `buf` with a fresh server-side (non-client) codec.
    fn decode(buf: &[u8]) -> Result<Option<Frame>, io::Error> {
        let mut eb = BytesMut::with_capacity(256);
        eb.extend(buf);
        let mut fc: FrameCodec = Default::default();
        fc.set_client(false);
        fc.decode(&mut eb)
    }

    /// True when the decoder asked for more bytes (`Ok(None)`).
    fn needs_more(result: Result<Option<Frame>, io::Error>) -> bool {
        match result {
            Ok(None) => true,
            _ => false,
        }
    }

    #[test]
    fn decode_partial_header() {
        assert!(needs_more(decode(&PARTIAL_HEADER)), "a lone header byte must not decode");
    }

    #[test]
    fn decode_partial_len_1() {
        assert!(needs_more(decode(&PARTIAL_LENGTH_1)), "incomplete 16-bit length must wait");
    }

    #[test]
    fn decode_partial_len_2() {
        assert!(needs_more(decode(&PARTIAL_LENGTH_2)), "incomplete 64-bit length must wait");
    }

    #[test]
    fn decode_partial_mask() {
        assert!(needs_more(decode(&PARTIAL_MASK)), "incomplete mask key must wait");
    }

    #[test]
    fn decode_partial_payload() {
        assert!(needs_more(decode(&PARTIAL_PAYLOAD)), "incomplete payload must wait");
    }

    #[test]
    fn decode_invalid_control_payload_len() {
        assert!(
            decode(&CTRL_PAYLOAD_LEN).is_err(),
            "control frames longer than 125 bytes must be rejected"
        );
    }

    #[test]
    fn decode_reserved() {
        // FIN plus exactly one of RSV3 (0x10), RSV2 (0x20) or RSV1 (0x40); the codec under
        // test permits no reserved bits.
        let reserved = [0x90_u8, 0xa0, 0xc0];
        for res in &reserved {
            let buf = [*res, 0x00];
            assert!(decode(&buf).is_err(), "rsv should not be set: {}", res);
        }
    }

    #[test]
    fn decode_fragmented_control() {
        // Close, ping and pong opcodes with the FIN bit clear.
        let control_opcodes = [0x08_u8, 0x09, 0x0A];
        for opcode in &control_opcodes {
            let buf = [*opcode, 0x00];
            assert!(
                decode(&buf).is_err(),
                "control frame {} is marked as fragment",
                opcode
            );
        }
    }

    #[test]
    fn decode_reserved_opcodes() {
        let reserved = [3_u8, 4, 5, 6, 7, 11, 12, 13, 14, 15];
        for res in &reserved {
            let buf = [0x80 | *res, 0x00];
            assert!(decode(&buf).is_err(), "opcode {} should be reserved", res);
        }
    }

    #[test]
    fn decode_no_mask() {
        assert!(
            decode(&NO_MASK).is_err(),
            "decoded frames should always have a mask"
        );
    }

    #[test]
    fn decode_ping_no_data() {
        match decode(&PING_NO_DATA) {
            Ok(Some(frame)) => {
                assert!(frame.fin());
                assert!(!frame.rsv1());
                assert!(!frame.rsv2());
                assert!(!frame.rsv3());
                assert!(frame.opcode() == OpCode::Ping);
                assert!(frame.payload_length() == 0);
                assert!(frame.extension_data().is_none());
                assert!(frame.application_data().is_empty());
            }
            _ => panic!("expected a complete, empty ping frame"),
        }
    }
}