1use std::default::Default;
2use std::fmt;
3use std::io::{ErrorKind, Read, Write};
4
5use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
6use bytes::Buf;
7use rand;
8
9use circular_buffer::CircularBuffer;
10use protocol::{CloseCode, OpCode};
11use result::{Error, Kind, Result};
12
/// XOR every byte of `buf` in place with `mask`, cycling through the four key
/// bytes (byte `i` of `buf` is combined with `mask[i % 4]`).
///
/// XOR is its own inverse, so calling this twice with the same key restores
/// the original bytes; the same routine therefore both masks and unmasks.
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
    for (index, byte) in buf.iter_mut().enumerate() {
        *byte ^= mask[index % 4];
    }
}
19
20#[derive(Debug, Clone)]
22pub struct Frame {
23 finished: bool,
24 rsv1: bool,
25 rsv2: bool,
26 rsv3: bool,
27 opcode: OpCode,
28
29 mask: Option<[u8; 4]>,
30
31 payload: Vec<u8>,
32}
33
34impl Frame {
35 #[inline]
38 pub fn len(&self) -> usize {
39 let mut header_length = 2;
40 let payload_len = self.payload().len();
41 if payload_len > 125 {
42 if payload_len <= u16::max_value() as usize {
43 header_length += 2;
44 } else {
45 header_length += 8;
46 }
47 }
48
49 if self.is_masked() {
50 header_length += 4;
51 }
52
53 header_length + payload_len
54 }
55
56 #[inline]
58 pub fn is_empty(&self) -> bool {
59 false
60 }
61
62 #[inline]
64 pub fn is_final(&self) -> bool {
65 self.finished
66 }
67
68 #[inline]
70 pub fn has_rsv1(&self) -> bool {
71 self.rsv1
72 }
73
74 #[inline]
76 pub fn has_rsv2(&self) -> bool {
77 self.rsv2
78 }
79
80 #[inline]
82 pub fn has_rsv3(&self) -> bool {
83 self.rsv3
84 }
85
86 #[inline]
88 pub fn opcode(&self) -> OpCode {
89 self.opcode
90 }
91
92 #[inline]
94 pub fn is_control(&self) -> bool {
95 self.opcode.is_control()
96 }
97
98 #[inline]
100 pub fn payload(&self) -> &Vec<u8> {
101 &self.payload
102 }
103
104 #[doc(hidden)]
106 #[inline]
107 pub fn is_masked(&self) -> bool {
108 self.mask.is_some()
109 }
110
111 #[doc(hidden)]
113 #[allow(dead_code)]
114 #[inline]
115 pub fn mask(&self) -> Option<&[u8; 4]> {
116 self.mask.as_ref()
117 }
118
119 #[allow(dead_code)]
121 #[inline]
122 pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
123 self.finished = is_final;
124 self
125 }
126
127 #[inline]
129 pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
130 self.rsv1 = has_rsv1;
131 self
132 }
133
134 #[inline]
136 pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
137 self.rsv2 = has_rsv2;
138 self
139 }
140
141 #[inline]
143 pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
144 self.rsv3 = has_rsv3;
145 self
146 }
147
148 #[allow(dead_code)]
150 #[inline]
151 pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
152 self.opcode = opcode;
153 self
154 }
155
156 #[allow(dead_code)]
158 #[inline]
159 pub fn payload_mut(&mut self) -> &mut Vec<u8> {
160 &mut self.payload
161 }
162
163 #[doc(hidden)]
169 #[inline]
170 pub fn set_mask(&mut self) -> &mut Frame {
171 self.mask = Some(rand::random());
172 self
173 }
174
175 #[doc(hidden)]
178 #[inline]
179 pub fn remove_mask(&mut self) -> &mut Frame {
180 self.mask
181 .take()
182 .map(|mask| apply_mask(&mut self.payload, &mask));
183 self
184 }
185
186 pub fn into_data(self) -> Vec<u8> {
188 self.payload
189 }
190
191 #[inline]
193 pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
194 debug_assert!(
195 match code {
196 OpCode::Text | OpCode::Binary | OpCode::Continue => true,
197 _ => false,
198 },
199 "Invalid opcode for data frame."
200 );
201
202 Frame {
203 finished,
204 opcode: code,
205 payload: data,
206 ..Frame::default()
207 }
208 }
209
210 #[inline]
212 pub fn pong(data: Vec<u8>) -> Frame {
213 Frame {
214 opcode: OpCode::Pong,
215 payload: data,
216 ..Frame::default()
217 }
218 }
219
220 #[inline]
222 pub fn ping(data: Vec<u8>) -> Frame {
223 Frame {
224 opcode: OpCode::Ping,
225 payload: data,
226 ..Frame::default()
227 }
228 }
229
230 #[inline]
232 pub fn close(code: CloseCode, reason: &str) -> Frame {
233 let payload = if let CloseCode::Empty = code {
234 Vec::new()
235 } else {
236 let u: u16 = code.into();
237 let raw = [(u >> 8) as u8, u as u8];
238 [&raw, reason.as_bytes()].concat()
239 };
240
241 Frame {
242 payload,
243 ..Frame::default()
244 }
245 }
246
247 pub fn parse(cursor: &mut CircularBuffer, max_payload_length: u64) -> Result<Option<Frame>> {
249 let size = cursor.remaining();
250 let initial = cursor.read_cursor();
251 trace!("Position in buffer {:?}", initial);
252
253 let mut head = [0u8; 2];
254 if cursor.read(&mut head)? != 2 {
255 cursor.set_read_cursor(initial);
256 return Ok(None);
257 }
258
259 trace!("Parsed headers {:?}", head);
260
261 let first = head[0];
262 let second = head[1];
263 trace!("First: {:b}", first);
264 trace!("Second: {:b}", second);
265
266 let finished = first & 0x80 != 0;
267
268 let rsv1 = first & 0x40 != 0;
269 let rsv2 = first & 0x20 != 0;
270 let rsv3 = first & 0x10 != 0;
271
272 let opcode = OpCode::from(first & 0x0F);
273 trace!("Opcode: {:?}", opcode);
274
275 let masked = second & 0x80 != 0;
276 trace!("Masked: {:?}", masked);
277
278 let mut header_length = 2;
279
280 let mut length = u64::from(second & 0x7F);
281
282 if let Some(length_nbytes) = match length {
283 126 => Some(2),
284 127 => Some(8),
285 _ => None,
286 } {
287 match cursor.read_uint::<BigEndian>(length_nbytes) {
288 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
289 cursor.set_read_cursor(initial);
290 return Ok(None);
291 }
292 Err(err) => {
293 return Err(Error::from(err));
294 }
295 Ok(read) => {
296 length = read;
297 }
298 };
299 header_length += length_nbytes as u64;
300 }
301 trace!("Payload length: {}", length);
302
303 if length > max_payload_length {
304 return Err(Error::new(
305 Kind::Protocol,
306 format!(
307 "Rejected frame with payload length exceeding defined max: {}.",
308 max_payload_length
309 ),
310 ));
311 }
312
313 let mask = if masked {
314 let mut mask_bytes = [0u8; 4];
315 if cursor.read(&mut mask_bytes)? != 4 {
316 cursor.set_read_cursor(initial);
317 return Ok(None);
318 } else {
319 header_length += 4;
320 Some(mask_bytes)
321 }
322 } else {
323 None
324 };
325
326 match length.checked_add(header_length) {
327 Some(l) if (size as u64) < l => {
328 cursor.set_read_cursor(initial);
329 return Ok(None);
330 }
331 Some(_) => (),
332 None => return Ok(None),
333 };
334
335 let data = cursor.read_exact_into_vec(length as usize);
336
337 if let OpCode::Bad = opcode {
339 return Err(Error::new(
340 Kind::Protocol,
341 format!("Encountered invalid opcode: {}", first & 0x0F),
342 ));
343 }
344
345 match opcode {
347 OpCode::Ping | OpCode::Pong if length > 125 => {
348 return Err(Error::new(
349 Kind::Protocol,
350 format!(
351 "Rejected WebSocket handshake.Received control frame with length: {}.",
352 length
353 ),
354 ))
355 }
356 OpCode::Close if length > 125 => {
357 debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
358 return Ok(Some(Frame::close(
359 CloseCode::Protocol,
360 "Received close frame with payload length exceeding 125.",
361 )));
362 }
363 _ => (),
364 }
365
366 let frame = Frame {
367 finished,
368 rsv1,
369 rsv2,
370 rsv3,
371 opcode,
372 mask,
373 payload: data,
374 };
375
376 Ok(Some(frame))
377 }
378
379 pub fn format<W>(&mut self, w: &mut W) -> Result<()>
381 where
382 W: Write,
383 {
384 let mut one = 0u8;
385 let code: u8 = self.opcode.into();
386 if self.is_final() {
387 one |= 0x80;
388 }
389 if self.has_rsv1() {
390 one |= 0x40;
391 }
392 if self.has_rsv2() {
393 one |= 0x20;
394 }
395 if self.has_rsv3() {
396 one |= 0x10;
397 }
398 one |= code;
399
400 let mut two = 0u8;
401 if self.is_masked() {
402 two |= 0x80;
403 }
404
405 match self.payload.len() {
406 len if len < 126 => {
407 two |= len as u8;
408 }
409 len if len <= 65535 => {
410 two |= 126;
411 }
412 _ => {
413 two |= 127;
414 }
415 }
416 w.write_all(&[one, two])?;
417
418 if let Some(length_bytes) = match self.payload.len() {
419 len if len < 126 => None,
420 len if len <= 65535 => Some(2),
421 _ => Some(8),
422 } {
423 w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes)?;
424 }
425
426 if self.is_masked() {
427 let mask = self.mask.take().unwrap();
428 apply_mask(&mut self.payload, &mask);
429 w.write_all(&mask)?;
430 }
431
432 w.write_all(&self.payload)?;
433 Ok(())
434 }
435}
436
437impl Default for Frame {
438 fn default() -> Frame {
439 Frame {
440 finished: true,
441 rsv1: false,
442 rsv2: false,
443 rsv3: false,
444 opcode: OpCode::Close,
445 mask: None,
446 payload: Vec::new(),
447 }
448 }
449}
450
451impl fmt::Display for Frame {
452 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
453 write!(
454 f,
455 "
456<FRAME>
457final: {}
458reserved: {} {} {}
459opcode: {}
460length: {}
461payload length: {}
462payload: 0x{}
463 ",
464 self.finished,
465 self.rsv1,
466 self.rsv2,
467 self.rsv3,
468 self.opcode,
469 self.len(),
471 self.payload.len(),
472 self.payload
473 .iter()
474 .map(|byte| format!("{:x}", byte))
475 .collect::<String>()
476 )
477 }
478}
479
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use protocol::OpCode;

    /// The `Display` output of a text frame shows the header fields and payload.
    #[test]
    fn display_frame() {
        let f = Frame::message("hi there".into(), OpCode::Text, true);
        let view = format!("{}", f);
        assert!(view.contains("<FRAME>"));
        assert!(view.contains("final: true"));
        assert!(view.contains("payload length: 8"));
        assert!(view.contains("payload:"));
    }

    /// `len()` reports exactly the number of bytes `format` writes, for each
    /// length encoding and with and without a mask.
    #[test]
    fn len_matches_formatted_size() {
        let small = Frame::message("hi there".into(), OpCode::Text, true);
        assert_eq!(small.len(), 2 + 8);
        let mut out = Vec::new();
        small.clone().format(&mut out).unwrap();
        assert_eq!(out.len(), small.len());

        let medium = Frame::message(vec![7u8; 200], OpCode::Binary, true);
        assert_eq!(medium.len(), 2 + 2 + 200);
        let mut out = Vec::new();
        medium.clone().format(&mut out).unwrap();
        assert_eq!(out.len(), medium.len());

        let mut masked = Frame::message("hi there".into(), OpCode::Text, true);
        masked.set_mask();
        assert_eq!(masked.len(), 2 + 4 + 8);
        let expected = masked.len();
        let mut out = Vec::new();
        masked.format(&mut out).unwrap();
        assert_eq!(out.len(), expected);
        // Writing consumes the mask.
        assert!(!masked.is_masked());
    }
}
491}