use std::convert::TryFrom;
use std::mem;
use std::usize;
use byteorder::{BigEndian, ByteOrder, NativeEndian};
use bytes::BytesMut;
use tokio_util::codec::{Decoder, Encoder};
use crate::mask::Mask;
use crate::{Error, Result};
/// Payload length of a frame, tagged with the wire encoding used for it.
///
/// The variant selects how the length is laid out in the header (see
/// `FrameHeader::write_to_slice` / `FrameHeader::parse_slice`):
/// `Small` is stored directly in the 7-bit length field, `Medium` follows the
/// escape value 126 as a big-endian `u16`, and `Large` follows the escape value
/// 127 as a big-endian `u64`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum DataLength {
/// Length stored directly in the 7-bit field; expected to be at most 125.
Small(u8),
/// Length stored as a 16-bit extended length (length field value 126).
Medium(u16),
/// Length stored as a 64-bit extended length (length field value 127).
Large(u64),
}
impl From<u64> for DataLength {
fn from(n: u64) -> Self {
if n <= 125 {
Self::Small(n as u8)
} else if n <= 65535 {
Self::Medium(n as u16)
} else {
Self::Large(n)
}
}
}
/// Validates a [`DataLength`] and returns the payload length it carries.
///
/// Each variant must hold a value that actually requires that encoding:
/// the minimal representation must be used, `Small` must stay within
/// `0..=125` (126 and 127 are the escape codes for the extended encodings,
/// and larger values would overwrite the mask bit), and the most significant
/// bit of a 64-bit length must be zero.
///
/// # Errors
///
/// Returns an error when the stored value is not valid for its variant.
impl TryFrom<DataLength> for u64 {
    type Error = Error;
    fn try_from(len: DataLength) -> Result<Self> {
        match len {
            DataLength::Small(n) => {
                // 126/127 are escape codes, not lengths; >127 collides with the mask bit.
                if n > 125 {
                    return Err(format!(
                        "payload length {} is out of range for the 7-bit length field",
                        n
                    )
                    .into());
                }
                Ok(u64::from(n))
            }
            DataLength::Medium(n) => {
                if n <= 125 {
                    return Err(format!(
                        "payload length {} should not be represented using 16 bits",
                        n
                    )
                    .into());
                }
                Ok(u64::from(n))
            }
            DataLength::Large(n) => {
                if n <= 65535 {
                    return Err(format!(
                        "payload length {} should not be represented using 64 bits",
                        n
                    )
                    .into());
                }
                if n >= 0x8000_0000_0000_0000 {
                    return Err(format!("frame is too long: {} bytes ({:x})", n, n).into());
                }
                Ok(n)
            }
        }
    }
}
/// Picks the shortest encoding for a payload length given as a `usize`.
///
/// The value is widened to `u64` (lossless wherever `usize` is at most 64
/// bits) and handed to the `From<u64>` conversion.
impl From<usize> for DataLength {
    fn from(n: usize) -> Self {
        let wide = n as u64;
        wide.into()
    }
}
/// Converts a [`DataLength`] into an in-memory payload length.
///
/// The value is first validated by the `u64` conversion and is then rejected
/// if it does not fit in this platform's `usize`.
///
/// # Errors
///
/// Propagates any error from `u64::try_from`, and returns an error when the
/// length exceeds `usize::MAX`.
impl TryFrom<DataLength> for usize {
    type Error = Error;
    fn try_from(len: DataLength) -> Result<Self> {
        let len = u64::try_from(len)?;
        if len > usize::MAX as u64 {
            return Err(format!(
                "frame of {} bytes can't be parsed on a {}-bit platform",
                len,
                // `size_of` reports bytes; multiply by 8 to get the width in bits.
                mem::size_of::<usize>() * 8
            )
            .into());
        }
        Ok(len as usize)
    }
}
/// Parsed frame header: every field that precedes the payload on the wire.
#[derive(Clone, Debug, PartialEq)]
pub struct FrameHeader {
/// FIN flag (bit `0x80` of the first header byte).
pub(crate) fin: bool,
/// Reserved bits, kept in place within the first header byte (mask `0x70`).
pub(crate) rsv: u8,
/// Opcode from the low nibble (`0x0f`) of the first header byte.
pub(crate) opcode: u8,
/// Masking key; `Some` exactly when the mask bit (`0x80` of the second byte) is set.
pub(crate) mask: Option<Mask>,
/// Payload length together with the encoding used for it on the wire.
pub(crate) data_len: DataLength,
}
impl FrameHeader {
pub fn new(fin: bool, rsv: u8, opcode: u8, mask: Option<Mask>, data_len: DataLength) -> Self {
Self {
fin,
rsv,
opcode,
mask,
data_len,
}
}
pub fn fin(&self) -> bool {
self.fin
}
pub fn rsv(&self) -> u8 {
self.rsv
}
pub fn opcode(&self) -> u8 {
self.opcode
}
pub fn mask(&self) -> Option<Mask> {
self.mask
}
pub fn data_len(&self) -> DataLength {
self.data_len
}
pub fn header_len(&self) -> usize {
let mut len = 1 + 1 ;
len += match self.data_len {
DataLength::Small(_) => 0,
DataLength::Medium(_) => 2,
DataLength::Large(_) => 8,
};
if self.mask.is_some() {
len += 4;
}
len
}
pub(crate) fn parse_slice(buf: &[u8]) -> Option<(Self, usize)> {
if buf.len() < 2 {
return None;
}
let fin_opcode = buf[0];
let mask_data_len = buf[1];
let mut header_len = 2;
let fin = (fin_opcode & 0x80) != 0;
let rsv = (fin_opcode & 0xf0) & !0x80;
let opcode = fin_opcode & 0x0f;
let (buf, data_len) = match mask_data_len & 0x7f {
127 => {
if buf.len() < 10 {
return None;
}
header_len += 8;
(&buf[10..], DataLength::Large(BigEndian::read_u64(&buf[2..10])))
}
126 => {
if buf.len() < 4 {
return None;
}
header_len += 2;
(&buf[4..], DataLength::Medium(BigEndian::read_u16(&buf[2..4])))
}
n => {
assert!(n < 126);
(&buf[2..], DataLength::Small(n))
}
};
let mask = if mask_data_len & 0x80 == 0 {
None
} else {
if buf.len() < 4 {
return None;
}
header_len += 4;
Some(NativeEndian::read_u32(buf).into())
};
let header = Self {
fin,
rsv,
opcode,
mask,
data_len,
};
debug_assert_eq!(header.header_len(), header_len);
Some((header, header_len))
}
pub(crate) fn write_to_slice(&self, dst: &mut [u8]) {
let FrameHeader {
fin,
rsv,
opcode,
mask,
data_len,
} = *self;
let mut fin_opcode = rsv | opcode;
if fin {
fin_opcode |= 0x80
};
dst[0] = fin_opcode;
let mask_bit = if mask.is_some() { 0x80 } else { 0 };
let dst = match data_len {
DataLength::Small(n) => {
dst[1] = mask_bit | n;
&mut dst[2..]
}
DataLength::Medium(n) => {
let (dst, rest) = dst.split_at_mut(4);
dst[1] = mask_bit | 126;
BigEndian::write_u16(&mut dst[2..4], n);
rest
}
DataLength::Large(n) => {
let (dst, rest) = dst.split_at_mut(10);
dst[1] = mask_bit | 127;
BigEndian::write_u64(&mut dst[2..10], n);
rest
}
};
if let Some(mask) = mask {
NativeEndian::write_u32(dst, mask.into());
}
}
pub(crate) fn write_to_bytes(&self, dst: &mut BytesMut) {
let data_len = match self.data_len {
DataLength::Small(n) => n as usize,
DataLength::Medium(n) => n as usize,
DataLength::Large(n) => n as usize,
};
let initial_len = dst.len();
let header_len = self.header_len();
dst.reserve(header_len + data_len);
unsafe {
dst.set_len(initial_len + header_len);
}
let dst_slice = &mut dst[initial_len..(initial_len + header_len)];
self.write_to_slice(dst_slice);
}
}
/// Stateless `tokio_util` codec that decodes and encodes [`FrameHeader`]s.
pub struct FrameHeaderCodec;
/// Decodes a [`FrameHeader`] from the front of the buffer.
///
/// Returns `Ok(None)` and leaves `src` untouched while the header is still
/// incomplete. Otherwise it consumes exactly the header bytes and returns the
/// header. This decoder never produces an error itself.
impl Decoder for FrameHeaderCodec {
    type Item = FrameHeader;
    type Error = Error;
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FrameHeader>> {
        use bytes::Buf;
        match FrameHeader::parse_slice(src.chunk()) {
            None => Ok(None),
            Some((header, consumed)) => {
                src.advance(consumed);
                Ok(Some(header))
            }
        }
    }
}
/// Encodes an owned header by delegating to the by-reference encoder.
impl Encoder<FrameHeader> for FrameHeaderCodec {
    type Error = Error;
    fn encode(&mut self, item: FrameHeader, dst: &mut BytesMut) -> Result<()> {
        <Self as Encoder<&FrameHeader>>::encode(self, &item, dst)
    }
}
impl<'a> Encoder<&'a FrameHeader> for FrameHeaderCodec {
type Error = Error;
fn encode(&mut self, item: &'a FrameHeader, dst: &mut BytesMut) -> Result<()> {
item.write_to_bytes(dst);
Ok(())
}
}
#[cfg(test)]
mod tests {
use assert_allocations::assert_allocated_bytes;
use bytes::BytesMut;
use tokio_util::codec::{Decoder, Encoder};
use crate::frame::{FrameHeader, FrameHeaderCodec};
// Property test: encoding a header and decoding it back yields an equal header,
// the encoded form is exactly `header_len()` bytes and is fully consumed by the
// decoder, and the whole round trip stays within an allocation budget.
// Only 16-bit payload lengths are generated, so the `Small` and `Medium`
// encodings are exercised but `Large` is not.
#[quickcheck]
fn round_trips(fin: bool, is_text: bool, mask: Option<u32>, data_len: u16) {
// Opcode 1 when `is_text`, otherwise 2; reserved bits are always clear.
// NOTE(review): `assert_allocated_bytes(n, f)` presumably asserts that `f`
// allocates exactly `n` bytes and returns its result — confirm against the crate.
let header = assert_allocated_bytes(0, || FrameHeader {
fin,
rsv: 0,
opcode: if is_text { 1 } else { 2 },
mask: mask.map(|n| n.into()),
data_len: (data_len as u64).into(),
});
// Budget: the encoder reserves room for the header plus the payload.
// NOTE(review): the `.max(8)` floor presumably reflects a minimum allocation
// size in `BytesMut` — confirm.
assert_allocated_bytes((header.header_len() + data_len as usize).max(8), || {
let mut codec = FrameHeaderCodec;
let mut bytes = BytesMut::new();
codec.encode(&header, &mut bytes).unwrap();
let header_len = header.header_len();
// Only the header is written; the payload space is reserved, not filled.
assert_eq!(bytes.len(), header_len);
let header2 = codec.decode(&mut bytes).unwrap().unwrap();
assert_eq!(header2.header_len(), header_len);
// Decoding consumed every encoded byte.
assert_eq!(bytes.len(), 0);
assert_eq!(header, header2)
})
}
}