use bytes::{Buf, BytesMut};
use tokio_util::codec::Decoder;
use super::types::{Frame, Limits, OpCode, Role};
use crate::{
mask,
proto::ProtocolError,
utf8::{self, Validator},
CloseCode, Error, Payload,
};
/// Largest possible WebSocket frame header in bytes: 2 fixed bytes, up to 8 bytes of
/// extended payload length and a 4-byte masking key.
const MAX_FRAME_HEADER_SIZE: usize = 2 + 8 + 4;
/// Stateful WebSocket frame decoder.
///
/// State is carried between `decode` calls so that partially received frames and
/// fragmented messages can be tracked.
#[derive(Debug)]
pub(super) struct WebSocketProtocol {
    /// Local role; decides whether incoming frames are required to be masked
    /// (`Role::Server`) or unmasked (`Role::Client`).
    pub(super) role: Role,
    /// Decoding limits; `max_payload_len` caps the accepted payload length.
    pub(super) limits: Limits,
    /// Opcode of the fragmented message currently being received, or
    /// `OpCode::Continuation` when no fragmented message is in progress.
    fragmented_message_opcode: OpCode,
    /// Number of payload bytes of a partially received text frame that have already
    /// been unmasked (if masked) and fed to `validator`; reset to 0 whenever a frame
    /// is returned.
    payload_processed: usize,
    /// Incremental UTF-8 validator for text payloads, fed chunk by chunk.
    validator: Validator,
}
impl WebSocketProtocol {
    /// Builds decoder state for the given `role` and `limits`, starting with no
    /// fragmented message in progress, no partially processed payload and a fresh
    /// UTF-8 validator.
    #[cfg(any(feature = "client", feature = "server"))]
    pub(super) fn new(role: Role, limits: Limits) -> Self {
        let validator = Validator::new();
        Self {
            validator,
            payload_processed: 0,
            fragmented_message_opcode: OpCode::Continuation,
            limits,
            role,
        }
    }
}
// Evaluates to the `$range` sub-slice of `$buf` when the buffer already holds those
// bytes. Otherwise it reserves extra capacity and makes the *enclosing function*
// return `Ok(None)` ("need more data" in `Decoder` terms).
macro_rules! get_buf_if_space {
    ($buf:expr, $range:expr) => {
        match $buf.get($range) {
            Some(chunk) => chunk,
            None => {
                $buf.reserve(MAX_FRAME_HEADER_SIZE - $range.len());
                return Ok(None);
            }
        }
    };
}
impl Decoder for WebSocketProtocol {
type Error = Error;
type Item = Frame;
#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let first_two_bytes = get_buf_if_space!(src, 0..2);
let fin = first_two_bytes[0] >> 7 != 0;
let rsv = first_two_bytes[0] & 0x70;
if rsv != 0 {
return Err(Error::Protocol(ProtocolError::InvalidRsv));
}
let opcode = OpCode::try_from(first_two_bytes[0] & 0xF)?;
if opcode.is_control() {
if !fin {
return Err(Error::Protocol(ProtocolError::FragmentedControlFrame));
}
} else if self.fragmented_message_opcode == OpCode::Continuation {
if opcode == OpCode::Continuation {
return Err(Error::Protocol(ProtocolError::InvalidOpcode));
}
} else if opcode != OpCode::Continuation {
return Err(Error::Protocol(ProtocolError::InvalidOpcode));
}
let masked = first_two_bytes[1] >> 7 != 0;
if masked && self.role == Role::Client {
return Err(Error::Protocol(ProtocolError::UnexpectedMaskedFrame));
} else if !masked && self.role == Role::Server {
return Err(Error::Protocol(ProtocolError::UnexpectedUnmaskedFrame));
}
let mut payload_length = (first_two_bytes[1] & 127) as usize;
let mut offset = 2;
if opcode == OpCode::Close && payload_length == 1 {
return Err(Error::Protocol(ProtocolError::InvalidPayloadLength));
} else if payload_length > 125 {
if opcode.is_control() {
return Err(Error::Protocol(ProtocolError::InvalidPayloadLength));
}
if payload_length == 126 {
payload_length =
u16::from_be_bytes(get_buf_if_space!(src, 2..4).try_into().unwrap()) as usize;
if payload_length <= 125 {
return Err(Error::Protocol(ProtocolError::InvalidPayloadLength));
}
offset = 4;
} else if payload_length == 127 {
payload_length =
u64::from_be_bytes(get_buf_if_space!(src, 2..10).try_into().unwrap()) as usize;
if u16::try_from(payload_length).is_ok() {
return Err(Error::Protocol(ProtocolError::InvalidPayloadLength));
}
offset = 10;
} else {
debug_assert!(false, "7 bit value expected to be <= 127");
}
}
if payload_length > self.limits.max_payload_len {
return Err(Error::PayloadTooLong {
len: payload_length,
max_len: self.limits.max_payload_len,
});
}
if masked {
offset += 4;
if src.len() < offset {
src.reserve(MAX_FRAME_HEADER_SIZE - 4);
return Ok(None);
}
}
if payload_length != 0 {
let is_text = opcode == OpCode::Text
|| (opcode == OpCode::Continuation
&& self.fragmented_message_opcode == OpCode::Text);
let payload_available = (src.len() - offset).min(payload_length);
let is_complete = payload_available == payload_length;
let payload = if masked && (is_complete || is_text) {
let (l, r) = unsafe { src.split_at_mut_unchecked(offset) };
let mask = unsafe { l.get_unchecked_mut(l.len() - 4..).try_into().unwrap() };
let payload =
unsafe { r.get_unchecked_mut(self.payload_processed..payload_available) };
mask::frame(mask, payload);
payload
} else {
unsafe {
src.get_unchecked_mut(
offset + self.payload_processed..offset + payload_available,
)
}
};
if is_text {
self.validator.feed(payload, is_complete && fin)?;
self.payload_processed = payload_available;
}
if !is_complete {
src.reserve(payload_length - payload_available);
return Ok(None);
}
if opcode == OpCode::Close {
let code = CloseCode::try_from(u16::from_be_bytes(unsafe {
src.get_unchecked(offset..offset + 2).try_into().unwrap()
}))?;
if code.is_reserved() {
return Err(Error::Protocol(ProtocolError::InvalidCloseCode));
}
let _reason = utf8::parse_str(unsafe {
src.get_unchecked(offset + 2..offset + payload_length)
})?;
}
}
src.advance(offset);
let mut payload = Payload::from(src.split_to(payload_length));
payload.set_utf8_validated(opcode == OpCode::Text && fin);
if (fin && opcode == OpCode::Continuation) || (!fin && opcode != OpCode::Continuation) {
self.fragmented_message_opcode = opcode;
}
self.payload_processed = 0;
Ok(Some(Frame {
opcode,
payload,
is_final: fin,
}))
}
}