/// One event produced while decoding a stream of WebSocket frames.
#[derive(PartialEq, Eq, Debug)]
pub enum Element<'a> {
    /// A frame header has been fully decoded.
    FrameStart {
        /// Whether the FIN bit (0x80 of the first header byte) is set.
        fin: bool,
        /// Payload length in bytes, as declared by the header.
        length: u64,
        /// Opcode taken from the low four bits of the first header byte.
        opcode: u8,
    },
    /// A chunk of payload belonging to the frame started by the last `FrameStart`.
    Data {
        /// The chunk's bytes, still masked; iterate over it to obtain them unmasked.
        data: Data<'a>,
        /// `true` for the chunk that completes the frame's payload.
        last_in_frame: bool,
    },
    /// The input violates a framing rule.
    Error {
        /// Static, human-readable description of the violation.
        desc: &'static str,
    },
}
/// A borrowed run of masked payload bytes.
///
/// Iterating over a `Data` yields the bytes with the masking key removed.
#[derive(PartialEq, Eq, Debug)]
pub struct Data<'a> {
    /// Raw (still masked) payload bytes.
    data: &'a [u8],
    /// 32-bit masking key; its most-significant byte applies to key position 0.
    mask: u32,
    /// Key position (0..4) that applies to `data[0]`.
    offset: u8,
}
/// Incremental frame decoder; keeps the parsing state between calls to `feed`.
pub struct StateMachine {
    /// Which part of a frame the next input byte belongs to.
    inner: StateMachineInner,
    /// Header bytes that arrived before the rest of their header.
    buffer: Vec<u8>,
}
/// Position of the decoder within the current frame.
enum StateMachineInner {
    /// Expecting (the remainder of) a frame header.
    InHeader,
    /// Inside a frame's payload.
    InData {
        /// Masking key read from the frame header.
        mask: u32,
        /// Key position (0..4) that applies to the next payload byte.
        offset: u8,
        /// Payload bytes of the current frame that have not been emitted yet.
        remaining_len: u64,
    },
}
impl StateMachine {
pub fn new() -> StateMachine {
StateMachine {
inner: StateMachineInner::InHeader,
buffer: Vec::with_capacity(14),
}
}
#[inline]
pub fn feed<'a>(&'a mut self, data: &'a [u8]) -> ElementsIter<'a> {
ElementsIter { state: self, data }
}
}
/// Iterator over the elements decoded from one chunk passed to `StateMachine::feed`.
pub struct ElementsIter<'a> {
    /// Decoder state, advanced as elements are produced.
    state: &'a mut StateMachine,
    /// The part of the fed chunk that has not been consumed yet.
    data: &'a [u8],
}
impl<'a> Iterator for ElementsIter<'a> {
type Item = Element<'a>;
fn next(&mut self) -> Option<Element<'a>> {
if self.data.is_empty() {
return None;
}
match self.state.inner {
StateMachineInner::InHeader => {
let total_buffered = self.state.buffer.len() + self.data.len();
if total_buffered < 6 {
self.state.buffer.extend_from_slice(self.data);
self.data = &[];
return None;
}
let (first_byte, second_byte) = {
let mut mask_iter = self.state.buffer.iter().chain(self.data.iter());
let first_byte = *mask_iter.next().unwrap();
let second_byte = *mask_iter.next().unwrap();
(first_byte, second_byte)
};
if (first_byte & 0x70) != 0 {
return Some(Element::Error {
desc: "Reserved bits must be zero",
});
}
if (second_byte & 0x80) == 0 {
return Some(Element::Error {
desc: "Client-to-server messages must be masked",
});
}
let (length, mask) = match second_byte & 0x7f {
126 => {
if total_buffered < 8 {
self.state.buffer.extend_from_slice(self.data);
self.data = &[];
return None;
}
let mut mask_iter =
self.state.buffer.iter().chain(self.data.iter()).skip(2);
let length = {
let a = u64::from(*mask_iter.next().unwrap());
let b = u64::from(*mask_iter.next().unwrap());
(a << 8) | (b << 0)
};
let mask = {
let a = u32::from(*mask_iter.next().unwrap());
let b = u32::from(*mask_iter.next().unwrap());
let c = u32::from(*mask_iter.next().unwrap());
let d = u32::from(*mask_iter.next().unwrap());
(a << 24) | (b << 16) | (c << 8) | (d << 0)
};
(length, mask)
}
127 => {
if total_buffered < 14 {
self.state.buffer.extend_from_slice(self.data);
self.data = &[];
return None;
}
let mut mask_iter =
self.state.buffer.iter().chain(self.data.iter()).skip(2);
let length = {
let a = u64::from(*mask_iter.next().unwrap());
let b = u64::from(*mask_iter.next().unwrap());
let c = u64::from(*mask_iter.next().unwrap());
let d = u64::from(*mask_iter.next().unwrap());
let e = u64::from(*mask_iter.next().unwrap());
let f = u64::from(*mask_iter.next().unwrap());
let g = u64::from(*mask_iter.next().unwrap());
let h = u64::from(*mask_iter.next().unwrap());
if (a & 0x80) != 0 {
return Some(Element::Error {
desc: "Most-significant bit of the length must be zero",
});
}
(a << 56)
| (b << 48)
| (c << 40)
| (d << 32)
| (e << 24)
| (f << 16)
| (g << 8)
| (h << 0)
};
let mask = {
let a = u32::from(*mask_iter.next().unwrap());
let b = u32::from(*mask_iter.next().unwrap());
let c = u32::from(*mask_iter.next().unwrap());
let d = u32::from(*mask_iter.next().unwrap());
(a << 24) | (b << 16) | (c << 8) | (d << 0)
};
(length, mask)
}
n => {
let mut mask_iter =
self.state.buffer.iter().chain(self.data.iter()).skip(2);
let mask = {
let a = u32::from(*mask_iter.next().unwrap());
let b = u32::from(*mask_iter.next().unwrap());
let c = u32::from(*mask_iter.next().unwrap());
let d = u32::from(*mask_iter.next().unwrap());
(a << 24) | (b << 16) | (c << 8) | (d << 0)
};
(u64::from(n), mask)
}
};
let data_start = {
let data_start_off = match second_byte & 0x7f {
126 => 8,
127 => 14,
_ => 6,
};
assert!(self.state.buffer.len() < data_start_off);
&self.data[(data_start_off - self.state.buffer.len())..]
};
self.data = data_start;
self.state.buffer.clear();
self.state.inner = StateMachineInner::InData {
mask,
remaining_len: length,
offset: 0,
};
Some(Element::FrameStart {
fin: (first_byte & 0x80) != 0,
length,
opcode: first_byte & 0xf,
})
}
StateMachineInner::InData {
mask,
ref mut remaining_len,
ref mut offset,
} if *remaining_len > self.data.len() as u64 => {
let data = Data {
data: self.data,
mask,
offset: *offset,
};
*offset += (self.data.len() % 4) as u8;
*offset %= 4;
*remaining_len -= self.data.len() as u64;
self.data = &[];
Some(Element::Data {
data,
last_in_frame: false,
})
}
StateMachineInner::InData {
mask,
remaining_len,
offset,
} => {
debug_assert!(self.data.len() as u64 >= remaining_len);
let data = Data {
data: &self.data[0..remaining_len as usize],
mask,
offset,
};
self.data = &self.data[remaining_len as usize..];
self.state.inner = StateMachineInner::InHeader;
debug_assert!(self.state.buffer.is_empty());
Some(Element::Data {
data,
last_in_frame: true,
})
}
}
}
}
impl<'a> Iterator for Data<'a> {
    type Item = u8;

    /// Returns the next payload byte with the masking key removed and advances the key position.
    #[inline]
    fn next(&mut self) -> Option<u8> {
        let remaining: &'a [u8] = self.data;
        let (&masked, rest) = remaining.split_first()?;
        // Key position 0 is the most-significant byte of `mask`.
        let key_byte = (self.mask >> (8 * (3 - self.offset))) as u8;
        self.data = rest;
        self.offset = (self.offset + 1) % 4;
        Some(masked ^ key_byte)
    }

    /// Exact: one item per remaining payload byte.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let count = self.data.len();
        (count, Some(count))
    }
}

// `size_hint` above is exact, so the default `len()` is accurate.
impl<'a> ExactSizeIterator for Data<'a> {}
#[cfg(test)]
mod tests {
    use super::{Element, StateMachine};

    /// Decodes, in a single feed, the masked single-frame text message "Hello" given as an
    /// example in RFC 6455 section 5.7.
    #[test]
    fn basic() {
        let frame: &[u8] = &[
            0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
        ];
        let mut decoder = StateMachine::new();
        let mut elements = decoder.feed(frame);

        let header = elements.next().unwrap();
        let expected_header = Element::FrameStart {
            fin: true,
            length: 5,
            opcode: 1,
        };
        assert_eq!(header, expected_header);

        if let Element::Data {
            data,
            last_in_frame,
        } = elements.next().unwrap()
        {
            assert!(last_in_frame);
            let payload: Vec<u8> = data.collect();
            assert_eq!(payload, b"Hello");
        } else {
            panic!();
        }
    }
}