1use std::default::Default;
2use std::fmt;
3use std::io::{Cursor, ErrorKind, Read, Write};
4
5use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
6use rand;
7
8use protocol::{CloseCode, OpCode};
9use result::{Error, Kind, Result};
10use stream::TryReadBuf;
11
/// XORs `buf` in place with the 4-byte masking key, repeating the key for as
/// long as there is data. Applying the same key twice restores the input.
fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) {
    // Walk the buffer one key-length at a time; the final chunk may be shorter
    // than the key, in which case `zip` simply stops early.
    for chunk in buf.chunks_mut(4) {
        for (byte, key) in chunk.iter_mut().zip(mask.iter()) {
            *byte ^= *key;
        }
    }
}
18
19#[derive(Debug, Clone)]
21pub struct Frame {
22 finished: bool,
23 rsv1: bool,
24 rsv2: bool,
25 rsv3: bool,
26 opcode: OpCode,
27
28 mask: Option<[u8; 4]>,
29
30 payload: Vec<u8>,
31}
32
33impl Frame {
34 #[inline]
37 pub fn len(&self) -> usize {
38 let mut header_length = 2;
39 let payload_len = self.payload().len();
40 if payload_len > 125 {
41 if payload_len <= u16::max_value() as usize {
42 header_length += 2;
43 } else {
44 header_length += 8;
45 }
46 }
47
48 if self.is_masked() {
49 header_length += 4;
50 }
51
52 header_length + payload_len
53 }
54
55 #[inline]
57 pub fn is_empty(&self) -> bool {
58 false
59 }
60
61 #[inline]
63 pub fn is_final(&self) -> bool {
64 self.finished
65 }
66
67 #[inline]
69 pub fn has_rsv1(&self) -> bool {
70 self.rsv1
71 }
72
73 #[inline]
75 pub fn has_rsv2(&self) -> bool {
76 self.rsv2
77 }
78
79 #[inline]
81 pub fn has_rsv3(&self) -> bool {
82 self.rsv3
83 }
84
85 #[inline]
87 pub fn opcode(&self) -> OpCode {
88 self.opcode
89 }
90
91 #[inline]
93 pub fn is_control(&self) -> bool {
94 self.opcode.is_control()
95 }
96
97 #[inline]
99 pub fn payload(&self) -> &Vec<u8> {
100 &self.payload
101 }
102
103 #[doc(hidden)]
105 #[inline]
106 pub fn is_masked(&self) -> bool {
107 self.mask.is_some()
108 }
109
110 #[doc(hidden)]
112 #[allow(dead_code)]
113 #[inline]
114 pub fn mask(&self) -> Option<&[u8; 4]> {
115 self.mask.as_ref()
116 }
117
118 #[allow(dead_code)]
120 #[inline]
121 pub fn set_final(&mut self, is_final: bool) -> &mut Frame {
122 self.finished = is_final;
123 self
124 }
125
126 #[inline]
128 pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame {
129 self.rsv1 = has_rsv1;
130 self
131 }
132
133 #[inline]
135 pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame {
136 self.rsv2 = has_rsv2;
137 self
138 }
139
140 #[inline]
142 pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame {
143 self.rsv3 = has_rsv3;
144 self
145 }
146
147 #[allow(dead_code)]
149 #[inline]
150 pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame {
151 self.opcode = opcode;
152 self
153 }
154
155 #[allow(dead_code)]
157 #[inline]
158 pub fn payload_mut(&mut self) -> &mut Vec<u8> {
159 &mut self.payload
160 }
161
162 #[doc(hidden)]
168 #[inline]
169 pub fn set_mask(&mut self) -> &mut Frame {
170 self.mask = Some(rand::random());
171 self
172 }
173
174 #[doc(hidden)]
177 #[inline]
178 pub fn remove_mask(&mut self) -> &mut Frame {
179 self.mask
180 .take()
181 .map(|mask| apply_mask(&mut self.payload, &mask));
182 self
183 }
184
185 pub fn into_data(self) -> Vec<u8> {
187 self.payload
188 }
189
190 #[inline]
192 pub fn message(data: Vec<u8>, code: OpCode, finished: bool) -> Frame {
193 debug_assert!(
194 match code {
195 OpCode::Text | OpCode::Binary | OpCode::Continue => true,
196 _ => false,
197 },
198 "Invalid opcode for data frame."
199 );
200
201 Frame {
202 finished,
203 opcode: code,
204 payload: data,
205 ..Frame::default()
206 }
207 }
208
209 #[inline]
211 pub fn pong(data: Vec<u8>) -> Frame {
212 Frame {
213 opcode: OpCode::Pong,
214 payload: data,
215 ..Frame::default()
216 }
217 }
218
219 #[inline]
221 pub fn ping(data: Vec<u8>) -> Frame {
222 Frame {
223 opcode: OpCode::Ping,
224 payload: data,
225 ..Frame::default()
226 }
227 }
228
229 #[inline]
231 pub fn close(code: CloseCode, reason: &str) -> Frame {
232 let payload = if let CloseCode::Empty = code {
233 Vec::new()
234 } else {
235 let u: u16 = code.into();
236 let raw = [(u >> 8) as u8, u as u8];
237 [&raw, reason.as_bytes()].concat()
238 };
239
240 Frame {
241 payload,
242 ..Frame::default()
243 }
244 }
245
246 pub fn parse(cursor: &mut Cursor<Vec<u8>>, max_payload_length: u64) -> Result<Option<Frame>> {
248 let size = cursor.get_ref().len() as u64 - cursor.position();
249 let initial = cursor.position();
250 trace!("Position in buffer {}", initial);
251
252 let mut head = [0u8; 2];
253 if cursor.read(&mut head)? != 2 {
254 cursor.set_position(initial);
255 return Ok(None);
256 }
257
258 trace!("Parsed headers {:?}", head);
259
260 let first = head[0];
261 let second = head[1];
262 trace!("First: {:b}", first);
263 trace!("Second: {:b}", second);
264
265 let finished = first & 0x80 != 0;
266
267 let rsv1 = first & 0x40 != 0;
268 let rsv2 = first & 0x20 != 0;
269 let rsv3 = first & 0x10 != 0;
270
271 let opcode = OpCode::from(first & 0x0F);
272 trace!("Opcode: {:?}", opcode);
273
274 let masked = second & 0x80 != 0;
275 trace!("Masked: {:?}", masked);
276
277 let mut header_length = 2;
278
279 let mut length = u64::from(second & 0x7F);
280
281 if let Some(length_nbytes) = match length {
282 126 => Some(2),
283 127 => Some(8),
284 _ => None,
285 } {
286 match cursor.read_uint::<BigEndian>(length_nbytes) {
287 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
288 cursor.set_position(initial);
289 return Ok(None);
290 }
291 Err(err) => {
292 return Err(Error::from(err));
293 }
294 Ok(read) => {
295 length = read;
296 }
297 };
298 header_length += length_nbytes as u64;
299 }
300 trace!("Payload length: {}", length);
301
302 if length > max_payload_length {
303 return Err(Error::new(
304 Kind::Protocol,
305 format!(
306 "Rejected frame with payload length exceeding defined max: {}.",
307 max_payload_length
308 ),
309 ));
310 }
311
312 let mask = if masked {
313 let mut mask_bytes = [0u8; 4];
314 if cursor.read(&mut mask_bytes)? != 4 {
315 cursor.set_position(initial);
316 return Ok(None);
317 } else {
318 header_length += 4;
319 Some(mask_bytes)
320 }
321 } else {
322 None
323 };
324
325 match length.checked_add(header_length) {
326 Some(l) if size < l => {
327 cursor.set_position(initial);
328 return Ok(None);
329 }
330 Some(_) => (),
331 None => return Ok(None),
332 };
333
334 let mut data = Vec::with_capacity(length as usize);
335 if length > 0 {
336 if let Some(read) = cursor.try_read_buf(&mut data)? {
337 debug_assert!(read == length as usize, "Read incorrect payload length!");
338 }
339 }
340
341 if let OpCode::Bad = opcode {
343 return Err(Error::new(
344 Kind::Protocol,
345 format!("Encountered invalid opcode: {}", first & 0x0F),
346 ));
347 }
348
349 match opcode {
351 OpCode::Ping | OpCode::Pong if length > 125 => {
352 return Err(Error::new(
353 Kind::Protocol,
354 format!(
355 "Rejected WebSocket handshake.Received control frame with length: {}.",
356 length
357 ),
358 ))
359 }
360 OpCode::Close if length > 125 => {
361 debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
362 return Ok(Some(Frame::close(
363 CloseCode::Protocol,
364 "Received close frame with payload length exceeding 125.",
365 )));
366 }
367 _ => (),
368 }
369
370 let frame = Frame {
371 finished,
372 rsv1,
373 rsv2,
374 rsv3,
375 opcode,
376 mask,
377 payload: data,
378 };
379
380 Ok(Some(frame))
381 }
382
383 pub fn format<W>(&mut self, w: &mut W) -> Result<()>
385 where
386 W: Write,
387 {
388 let mut one = 0u8;
389 let code: u8 = self.opcode.into();
390 if self.is_final() {
391 one |= 0x80;
392 }
393 if self.has_rsv1() {
394 one |= 0x40;
395 }
396 if self.has_rsv2() {
397 one |= 0x20;
398 }
399 if self.has_rsv3() {
400 one |= 0x10;
401 }
402 one |= code;
403
404 let mut two = 0u8;
405 if self.is_masked() {
406 two |= 0x80;
407 }
408
409 match self.payload.len() {
410 len if len < 126 => {
411 two |= len as u8;
412 }
413 len if len <= 65535 => {
414 two |= 126;
415 }
416 _ => {
417 two |= 127;
418 }
419 }
420 w.write_all(&[one, two])?;
421
422 if let Some(length_bytes) = match self.payload.len() {
423 len if len < 126 => None,
424 len if len <= 65535 => Some(2),
425 _ => Some(8),
426 } {
427 w.write_uint::<BigEndian>(self.payload.len() as u64, length_bytes)?;
428 }
429
430 if self.is_masked() {
431 let mask = self.mask.take().unwrap();
432 apply_mask(&mut self.payload, &mask);
433 w.write_all(&mask)?;
434 }
435
436 w.write_all(&self.payload)?;
437 Ok(())
438 }
439}
440
441impl Default for Frame {
442 fn default() -> Frame {
443 Frame {
444 finished: true,
445 rsv1: false,
446 rsv2: false,
447 rsv3: false,
448 opcode: OpCode::Close,
449 mask: None,
450 payload: Vec::new(),
451 }
452 }
453}
454
455impl fmt::Display for Frame {
456 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
457 write!(
458 f,
459 "
460<FRAME>
461final: {}
462reserved: {} {} {}
463opcode: {}
464length: {}
465payload length: {}
466payload: 0x{}
467 ",
468 self.finished,
469 self.rsv1,
470 self.rsv2,
471 self.rsv3,
472 self.opcode,
473 self.len(),
475 self.payload.len(),
476 self.payload
477 .iter()
478 .map(|byte| format!("{:x}", byte))
479 .collect::<String>()
480 )
481 }
482}
483
// Only compiled for test builds.
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use protocol::OpCode;

    /// The `Display` output of a text frame must include the payload section
    /// with the payload rendered as hex.
    #[test]
    fn display_frame() {
        let f = Frame::message("hi there".into(), OpCode::Text, true);
        let view = format!("{}", f);
        // Previously the result of `contains` was discarded, so this test
        // could never fail.
        assert!(view.contains("payload:"));
        assert!(view.contains("payload length: 8"));
        // "hi there" as hex; every byte is >= 0x10, so this holds with or
        // without zero padding.
        assert!(view.contains("payload: 0x6869207468657265"));
    }
}