use bytes::{Buf, BytesMut};
use tokio_util::codec;
use crate::{
frame::{self, Frame, OpCode, MAX_HEAD_SIZE},
Role, WebSocketError,
};
/// Header fields of a frame whose payload has not been fully buffered yet.
///
/// The decoder keeps one of these between `decode` calls so the header does
/// not have to be parsed again once more bytes arrive.
#[repr(C)]
struct ReadState {
// Packed header bits: `0x80` = FIN, `0x40` = RSV1, `0x10` = mask present,
// low nibble (`0x0F`) = opcode.
flags: u8,
// Masking key; all zeroes when the `0x10` bit of `flags` is clear.
mask: [u8; 4],
// Always written as zeroes and never read.
// NOTE(review): presumably explicit padding so the fields before
// `payload_len` add up to 8 bytes — confirm intent.
_reserved: [u8; 3],
// Payload length announced by the frame header, in bytes.
payload_len: usize,
}
impl ReadState {
#[inline(always)]
fn new(
fin: bool,
rsv1: bool,
opcode: OpCode,
mask: Option<[u8; 4]>,
payload_len: usize,
) -> Self {
let flags = ((fin as u8) << 7)
| ((rsv1 as u8) << 6)
| ((opcode as u8) & 0x0F)
| if mask.is_some() { 0x10 } else { 0 };
Self {
flags,
mask: mask.unwrap_or([0; 4]),
_reserved: [0; 3],
payload_len,
}
}
#[inline(always)]
fn fin(&self) -> bool {
self.flags & 0x80 != 0
}
#[inline(always)]
fn rsv1(&self) -> bool {
self.flags & 0x40 != 0
}
#[inline(always)]
fn opcode(&self) -> OpCode {
unsafe { std::mem::transmute(self.flags & 0x0F) }
}
#[inline(always)]
fn mask(&self) -> Option<[u8; 4]> {
if self.flags & 0x10 != 0 {
Some(self.mask)
} else {
None
}
}
}
/// Pairs a [`Decoder`] with an [`Encoder`] so that a single value provides
/// both directions of the frame codec.
pub struct Codec {
// Handles `decode` calls.
decoder: Decoder,
// Handles `encode` calls.
encoder: Encoder,
}
impl From<(Decoder, Encoder)> for Codec {
fn from((decoder, encoder): (Decoder, Encoder)) -> Self {
Self { decoder, encoder }
}
}
impl codec::Decoder for Codec {
    type Item = <Decoder as codec::Decoder>::Item;
    type Error = <Decoder as codec::Decoder>::Error;

    /// Forwards to the wrapped [`Decoder`]; see its `decode` for the contract.
    #[inline]
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let inner = &mut self.decoder;
        <Decoder as codec::Decoder>::decode(inner, src)
    }
}
impl codec::Encoder<Frame> for Codec {
    type Error = <Encoder as codec::Encoder<Frame>>::Error;

    /// Forwards to the wrapped [`Encoder`]; see its `encode` for the contract.
    #[inline]
    fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let inner = &mut self.encoder;
        <Encoder as codec::Encoder<Frame>>::encode(inner, item, dst)
    }
}
/// Incremental decoder that turns buffered bytes into [`Frame`]s.
pub struct Decoder {
// Endpoint role; when it is `Role::Server`, `decode` requires every frame to
// carry a masking key and unmasks the payload.
role: Role,
// Header of a frame whose payload was still incomplete on the previous
// `decode` call; `None` when the next bytes start a new frame.
state: Option<ReadState>,
// Frames whose announced payload length is greater than or equal to this
// value are rejected with `WebSocketError::FrameTooLarge`.
max_payload_size: usize,
}
impl Decoder {
pub fn new(role: Role, max_payload_size: usize) -> Self {
Self {
role,
state: None,
max_payload_size,
}
}
}
impl codec::Decoder for Decoder {
type Item = Frame;
type Error = WebSocketError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if let Some(state) = self.state.take() {
if src.remaining() < state.payload_len {
self.state = Some(state);
return Ok(None);
}
if self.role == Role::Server {
let Some(mask) = state.mask() else {
return Err(WebSocketError::InvalidFragment);
};
crate::mask::apply_mask(&mut src[..state.payload_len], mask);
}
let payload = src.split_to(state.payload_len).freeze();
let mut frame = Frame::new(state.fin(), state.opcode(), state.mask(), payload);
frame.is_compressed = state.rsv1();
return Ok(Some(frame));
}
if src.remaining() < 2 {
return Ok(None);
}
let fin = src[0] & 0b10000000 != 0;
let rsv1 = src[0] & 0b01000000 != 0;
if src[0] & 0b00110000 != 0 {
return Err(WebSocketError::ReservedBitsNotZero);
}
let opcode = frame::OpCode::try_from(src[0] & 0b00001111)?;
let masked = src[1] & 0b10000000 != 0;
let length_code = src[1] & 0x7F;
let extra = match length_code {
126 => 2,
127 => 8,
_ => 0,
};
let header_size = 2 + extra + (masked as usize * 4);
if src.remaining() < header_size {
return Ok(None);
}
src.advance(2);
let payload_len: usize = match extra {
0 => usize::from(length_code),
2 => src.get_u16() as usize,
#[cfg(target_pointer_width = "64")]
8 => src.get_u64() as usize,
#[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
8 => match usize::try_from(src.get_u64()) {
Ok(length) => length,
Err(_) => return Err(WebSocketError::FrameTooLarge),
},
_ => unreachable!(),
};
let mask = if masked {
Some(src.get_u32().to_be_bytes())
} else {
None
};
if opcode.is_control() && !fin {
return Err(WebSocketError::ControlFrameFragmented);
}
if opcode == OpCode::Ping && payload_len > 125 {
return Err(WebSocketError::PingFrameTooLarge);
}
if payload_len >= self.max_payload_size {
return Err(WebSocketError::FrameTooLarge);
}
if src.remaining() < payload_len {
self.state = Some(ReadState::new(fin, rsv1, opcode, mask, payload_len));
return Ok(None);
}
if self.role == Role::Server {
let Some(mask) = mask else {
return Err(WebSocketError::InvalidFragment);
};
crate::mask::apply_mask(&mut src[..payload_len], mask);
}
let payload = src.split_to(payload_len).freeze();
let mut frame = Frame::new(fin, opcode, mask, payload);
frame.is_compressed = rsv1;
Ok(Some(frame))
}
}
/// Encoder that serializes [`Frame`]s into an output buffer.
pub struct Encoder {
// Endpoint role; when it is `Role::Client`, `encode` calls
// `set_random_mask_if_not_set` on each frame before writing it.
role: Role,
}
impl Encoder {
pub fn new(role: Role) -> Self {
Self { role }
}
}
impl codec::Encoder<Frame> for Encoder {
    type Error = WebSocketError;

    /// Appends `frame` (head followed by payload) to `dst`.
    ///
    /// For `Role::Client` the frame gets a random mask if it has none. If the
    /// frame ends up with a mask, the copy of the payload inside `dst` is
    /// masked in place; `frame.payload` itself is not modified here. This
    /// implementation always returns `Ok(())`.
    #[inline(always)]
    fn encode(&mut self, mut frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
        if self.role == Role::Client {
            frame.set_random_mask_if_not_set();
        }

        // Reserve room for the largest possible head plus the payload up front.
        dst.reserve(MAX_HEAD_SIZE + frame.payload.len());
        frame.write_head(dst);

        // Everything appended from here on is payload.
        let payload_start = dst.len();
        dst.extend_from_slice(&frame.payload);

        match frame.mask {
            Some(key) => crate::mask::apply_mask(&mut dst[payload_start..], key),
            None => {}
        }
        Ok(())
    }
}