1use std::default::Default;
2use std::fmt;
3use std::io::{Cursor, ErrorKind, Read, Write};
4
5use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
6use rand;
7
8use protocol::{CloseCode, OpCode};
9use result::{Error, Kind, Result};
10use stream::TryReadBuf;
11
/// XORs every byte of `buf` in place with the 4-byte masking key, cycling
/// through the key so that byte `i` is combined with `mask[i % 4]`.
///
/// The operation is its own inverse: applying the same key twice restores the
/// original bytes, which is why it is used both to mask and to unmask.
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
    for (index, byte) in buf.iter_mut().enumerate() {
        *byte ^= mask[index % 4];
    }
}
18
/// Largest payload (in bytes) that `Frame::parse` accepts: 5 MiB.
/// Frames announcing a longer payload are rejected with a protocol error.
const MAX_FRAME_PAYLOAD_SIZE: u64 = 5 << 20;
20
/// A single WebSocket frame: header flags, opcode, optional masking key and
/// payload bytes.
#[derive(Debug, Clone)]
pub struct Frame {
    // FIN bit (0x80 of the first header byte): set on the final fragment.
    finished: bool,
    // Reserved bits RSV1/RSV2/RSV3 (0x40/0x20/0x10 of the first header byte).
    rsv1: bool,
    rsv2: bool,
    rsv3: bool,
    // Opcode from the low nibble of the first header byte.
    opcode: OpCode,

    // 4-byte masking key; `Some` when the MASK bit (0x80 of the second header
    // byte) is set.
    mask: Option<[u8; 4]>,

    // Payload bytes. For masked frames produced by `parse` these are still
    // masked until `remove_mask` is called.
    payload: Vec<u8>,
}
34
35impl Frame {
36 #[inline]
39 pub fn len(&self) -> usize {
40 let mut header_length = 2;
41 let payload_len = self.payload().len();
42 if payload_len > 125 {
43 if payload_len <= u16::max_value() as usize {
44 header_length += 2;
45 } else {
46 header_length += 8;
47 }
48 }
49
50 if self.is_masked() {
51 header_length += 4;
52 }
53
54 header_length + payload_len
55 }
56
57 #[inline]
59 pub fn is_final(&self) -> bool {
60 self.finished
61 }
62
63 #[inline]
65 pub fn has_rsv1(&self) -> bool {
66 self.rsv1
67 }
68
69 #[inline]
71 pub fn has_rsv2(&self) -> bool {
72 self.rsv2
73 }
74
75 #[inline]
77 pub fn has_rsv3(&self) -> bool {
78 self.rsv3
79 }
80
81 #[inline]
83 pub fn opcode(&self) -> OpCode {
84 self.opcode
85 }
86
87 #[inline]
89 pub fn is_control(&self) -> bool {
90 self.opcode.is_control()
91 }
92
93 #[inline]
95 pub fn payload(&self) -> &Vec<u8> {
96 &self.payload
97 }
98
99 #[doc(hidden)]
101 #[inline]
102 pub fn is_masked(&self) -> bool {
103 self.mask.is_some()
104 }
105
106 #[doc(hidden)]
108 #[allow(dead_code)]
109 #[inline]
110 pub fn mask(&self) -> Option<&[u8; 4]> {
111 self.mask.as_ref()
112 }
113
114 #[allow(dead_code)]
116 #[inline]
117 pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
118 self.finished = is_final;
119 self
120 }
121
122 #[inline]
124 pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
125 self.rsv1 = has_rsv1;
126 self
127 }
128
129 #[inline]
131 pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
132 self.rsv2 = has_rsv2;
133 self
134 }
135
136 #[inline]
138 pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
139 self.rsv3 = has_rsv3;
140 self
141 }
142
143 #[allow(dead_code)]
145 #[inline]
146 pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
147 self.opcode = opcode;
148 self
149 }
150
151 #[allow(dead_code)]
153 #[inline]
154 pub fn payload_mut(&mut self) -> &mut Vec<u8> {
155 &mut self.payload
156 }
157
158 #[doc(hidden)]
164 #[inline]
165 pub fn set_mask(&mut self) -> &mut Frame {
166 self.mask = Some(rand::random());
167 self
168 }
169
170 #[doc(hidden)]
173 #[inline]
174 pub fn remove_mask(&mut self) -> &mut Frame {
175 self.mask
176 .and_then(|mask| Some(apply_mask(&mut self.payload, &mask)));
177 self.mask = None;
178 self
179 }
180
181 pub fn into_data(self) -> Vec<u8> {
183 self.payload
184 }
185
186 #[inline]
188 pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
189 debug_assert!(
190 match code {
191 OpCode::Text | OpCode::Binary | OpCode::Continue => true,
192 _ => false,
193 },
194 "Invalid opcode for data frame."
195 );
196
197 Frame {
198 finished: finished,
199 opcode: code,
200 payload: data,
201 ..Frame::default()
202 }
203 }
204
205 #[inline]
207 pub fn pong(data: Vec<u8>) -> Frame {
208 Frame {
209 opcode: OpCode::Pong,
210 payload: data,
211 ..Frame::default()
212 }
213 }
214
215 #[inline]
217 pub fn ping(data: Vec<u8>) -> Frame {
218 Frame {
219 opcode: OpCode::Ping,
220 payload: data,
221 ..Frame::default()
222 }
223 }
224
225 #[inline]
227 pub fn close(code: CloseCode, reason: &str) -> Frame {
228 let payload = if let CloseCode::Empty = code {
229 Vec::new()
230 } else {
231 let u: u16 = code.into();
232 let raw = [(u >> 8) as u8, u as u8];
233 [&raw, reason.as_bytes()].concat()
234 };
235
236 Frame {
237 payload: payload,
238 ..Frame::default()
239 }
240 }
241
242 pub fn parse(cursor: &mut Cursor<Vec<u8>>) -> Result<Option<Frame>> {
244 let size = cursor.get_ref().len() as u64 - cursor.position();
245 let initial = cursor.position();
246 trace!("Position in buffer {}", initial);
247
248 let mut head = [0u8; 2];
249 if try!(cursor.read(&mut head)) != 2 {
250 cursor.set_position(initial);
251 return Ok(None);
252 }
253
254 trace!("Parsed headers {:?}", head);
255
256 let first = head[0];
257 let second = head[1];
258 trace!("First: {:b}", first);
259 trace!("Second: {:b}", second);
260
261 let finished = first & 0x80 != 0;
262
263 let rsv1 = first & 0x40 != 0;
264 let rsv2 = first & 0x20 != 0;
265 let rsv3 = first & 0x10 != 0;
266
267 let opcode = OpCode::from(first & 0x0F);
268 trace!("Opcode: {:?}", opcode);
269
270 let masked = second & 0x80 != 0;
271 trace!("Masked: {:?}", masked);
272
273 let mut header_length = 2;
274
275 let mut length = (second & 0x7F) as u64;
276
277 if let Some(length_nbytes) = match length {
278 126 => Some(2),
279 127 => Some(8),
280 _ => None,
281 } {
282 match cursor.read_uint::<BigEndian>(length_nbytes) {
283 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
284 cursor.set_position(initial);
285 return Ok(None);
286 }
287 Err(err) => {
288 return Err(Error::from(err));
289 }
290 Ok(read) => {
291 length = read;
292 }
293 };
294 header_length += length_nbytes as u64;
295 }
296
297 trace!("Payload length: {}", length);
298
299 if length > MAX_FRAME_PAYLOAD_SIZE {
300 warn!(
301 "Payload exceeds max: {} > {}. Rejecting.",
302 length, MAX_FRAME_PAYLOAD_SIZE
303 );
304 return Err(Error::new(
305 Kind::Protocol,
306 format!("Rejecting too large payload: {}", length),
307 ));
308 }
309
310 let mask = if masked {
311 let mut mask_bytes = [0u8; 4];
312 if try!(cursor.read(&mut mask_bytes)) != 4 {
313 cursor.set_position(initial);
314 return Ok(None);
315 } else {
316 header_length += 4;
317 Some(mask_bytes)
318 }
319 } else {
320 None
321 };
322
323 match length.checked_add(header_length) {
324 Some(l) if size < l => {
325 cursor.set_position(initial);
326 return Ok(None);
327 }
328 Some(_) => (),
329 None => return Ok(None),
330 };
331
332 let mut data = Vec::with_capacity(length as usize);
333 if length > 0 {
334 if let Some(read) = try!(cursor.try_read_buf(&mut data)) {
335 debug_assert!(read == length as usize, "Read incorrect payload length!");
336 }
337 }
338
339 if let OpCode::Bad = opcode {
341 return Err(Error::new(
342 Kind::Protocol,
343 format!("Encountered invalid opcode: {}", first & 0x0F),
344 ));
345 }
346
347 match opcode {
349 OpCode::Ping | OpCode::Pong if length > 125 => {
350 return Err(Error::new(
351 Kind::Protocol,
352 format!(
353 "Rejected WebSocket handshake.Received control frame with length: {}.",
354 length
355 ),
356 ))
357 }
358 OpCode::Close if length > 125 => {
359 debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
360 return Ok(Some(Frame::close(
361 CloseCode::Protocol,
362 "Received close frame with payload length exceeding 125.",
363 )));
364 }
365 _ => (),
366 }
367
368 let frame = Frame {
369 finished: finished,
370 rsv1: rsv1,
371 rsv2: rsv2,
372 rsv3: rsv3,
373 opcode: opcode,
374 mask: mask,
375 payload: data,
376 };
377
378 Ok(Some(frame))
379 }
380
381 pub fn format<W>(&mut self, w: &mut W) -> Result<()>
383 where
384 W: Write,
385 {
386 let mut one = 0u8;
387 let code: u8 = self.opcode.into();
388 if self.is_final() {
389 one |= 0x80;
390 }
391 if self.has_rsv1() {
392 one |= 0x40;
393 }
394 if self.has_rsv2() {
395 one |= 0x20;
396 }
397 if self.has_rsv3() {
398 one |= 0x10;
399 }
400 one |= code;
401
402 let mut two = 0u8;
403 if self.is_masked() {
404 two |= 0x80;
405 }
406
407 match self.payload.len() {
408 len if len < 126 => {
409 two |= len as u8;
410 }
411 len if len <= 65535 => {
412 two |= 126;
413 }
414 _ => {
415 two |= 127;
416 }
417 }
418 try!(w.write(&[one, two]));
419
420 if let Some(length_bytes) = match self.payload.len() {
421 len if len < 126 => None,
422 len if len <= 65535 => Some(2),
423 _ => Some(8),
424 } {
425 try!(w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes));
426 }
427
428 if self.is_masked() {
429 let mask = self.mask.take().unwrap();
430 apply_mask(&mut self.payload, &mask);
431 try!(w.write(&mask));
432 }
433
434 try!(w.write(&self.payload));
435 Ok(())
436 }
437}
438
439impl Default for Frame {
440 fn default() -> Frame {
441 Frame {
442 finished: true,
443 rsv1: false,
444 rsv2: false,
445 rsv3: false,
446 opcode: OpCode::Close,
447 mask: None,
448 payload: Vec::new(),
449 }
450 }
451}
452
453impl fmt::Display for Frame {
454 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
455 write!(
456 f,
457 "
458<FRAME>
459final: {}
460reserved: {} {} {}
461opcode: {}
462length: {}
463payload length: {}
464payload: 0x{}
465 ",
466 self.finished,
467 self.rsv1,
468 self.rsv2,
469 self.rsv3,
470 self.opcode,
471 self.len(),
473 self.payload.len(),
474 self.payload
475 .iter()
476 .map(|byte| format!("{:x}", byte))
477 .collect::<String>()
478 )
479 }
480}
481
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use protocol::OpCode;
    use std::io::Cursor;

    #[test]
    fn display_frame() {
        let f = Frame::message("hi there".into(), OpCode::Text, true);
        let view = format!("{}", f);
        // The original discarded this result, so the test asserted nothing.
        assert!(view.contains("payload:"));
    }

    /// Payloads on both sides of the 7-bit / 16-bit / 64-bit length boundaries
    /// survive a format -> parse round trip, and parsing consumes the whole frame.
    #[test]
    fn format_then_parse_roundtrip() {
        for &size in &[0usize, 5, 125, 126, 200, 65535, 65536] {
            let data = vec![0xABu8; size];
            let mut frame = Frame::message(data.clone(), OpCode::Binary, true);
            let mut wire = Vec::new();
            frame.format(&mut wire).unwrap();
            assert_eq!(wire.len(), Frame::message(data.clone(), OpCode::Binary, true).len());

            let mut cursor = Cursor::new(wire);
            let parsed = Frame::parse(&mut cursor).unwrap().expect("complete frame");
            assert!(parsed.is_final());
            assert_eq!(parsed.payload(), &data);
            assert_eq!(cursor.position() as usize, cursor.get_ref().len());
        }
    }

    /// Two frames back to back are parsed one at a time without bleeding into each other.
    #[test]
    fn parse_consumes_exactly_one_frame() {
        let mut wire = Vec::new();
        Frame::message(b"first".to_vec(), OpCode::Text, true)
            .format(&mut wire)
            .unwrap();
        Frame::message(b"second".to_vec(), OpCode::Text, true)
            .format(&mut wire)
            .unwrap();

        let mut cursor = Cursor::new(wire);
        let a = Frame::parse(&mut cursor).unwrap().unwrap();
        let b = Frame::parse(&mut cursor).unwrap().unwrap();
        assert_eq!(a.into_data(), b"first".to_vec());
        assert_eq!(b.into_data(), b"second".to_vec());
    }

    /// A truncated frame yields `Ok(None)` and leaves the cursor where it started.
    #[test]
    fn parse_incomplete_frame_rewinds() {
        let mut wire = Vec::new();
        Frame::message(b"payload".to_vec(), OpCode::Binary, true)
            .format(&mut wire)
            .unwrap();
        wire.pop();

        let mut cursor = Cursor::new(wire);
        assert!(Frame::parse(&mut cursor).unwrap().is_none());
        assert_eq!(cursor.position(), 0);
    }

    /// Masking applied by `format` is undone by `remove_mask` after parsing.
    #[test]
    fn masked_roundtrip() {
        let data = b"masked payload".to_vec();
        let mut frame = Frame::message(data.clone(), OpCode::Binary, true);
        frame.set_mask();
        let mut wire = Vec::new();
        frame.format(&mut wire).unwrap();

        let mut cursor = Cursor::new(wire);
        let mut parsed = Frame::parse(&mut cursor).unwrap().unwrap();
        assert!(parsed.is_masked());
        parsed.remove_mask();
        assert!(!parsed.is_masked());
        assert_eq!(parsed.payload(), &data);
    }

    /// A header announcing a payload above the limit is rejected up front.
    #[test]
    fn parse_rejects_oversized_payload() {
        // 64-bit extended length of 0x600000 (6 MiB) > 5 MiB limit.
        let wire = vec![0x82u8, 127, 0, 0, 0, 0, 0, 0x60, 0, 0];
        let mut cursor = Cursor::new(wire);
        assert!(Frame::parse(&mut cursor).is_err());
    }

    /// Ping frames longer than 125 bytes are a protocol error.
    #[test]
    fn parse_rejects_long_ping() {
        let mut wire = Vec::new();
        Frame::ping(vec![0u8; 126]).format(&mut wire).unwrap();
        let mut cursor = Cursor::new(wire);
        assert!(Frame::parse(&mut cursor).is_err());
    }
}