use std::hint::unreachable_unchecked;
use bytes::{Buf, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use super::types::{Frame, Limits, OpCode, Role};
use crate::{
mask,
proto::ProtocolError,
utf8::{self, Validator},
CloseCode, Error,
};
/// Stateful WebSocket frame decoder.
///
/// Besides per-connection configuration it carries the state that must survive
/// between `decode` calls: which data message (if any) is being continued by
/// fragments, how much of a partially received text payload has already been
/// processed, and the incremental UTF-8 validator for text messages.
#[derive(Debug)]
pub(super) struct WebsocketProtocol {
/// Which side of the connection this codec serves; decides whether incoming
/// frames must be masked (server) or must not be masked (client).
pub(super) role: Role,
/// Limits enforced while decoding (e.g. `max_payload_len`).
pub(super) limits: Limits,
/// Opcode of the fragmented data message currently in progress, or
/// `OpCode::Continuation` when no fragmented message is in progress.
fragmented_message_opcode: OpCode,
/// Number of payload bytes of the current, still-incomplete text frame that
/// have already been unmasked and fed to `validator`; reset to 0 once a frame
/// is returned.
payload_processed: usize,
/// Incremental UTF-8 validator fed with text-message payloads, including
/// across continuation frames.
validator: Validator,
}
impl WebsocketProtocol {
    /// Builds a decoder for `role` that enforces `limits`.
    ///
    /// The decoder starts with no fragmented message in progress
    /// (`OpCode::Continuation`), no partially processed payload, and a fresh
    /// UTF-8 validator.
    #[cfg(any(feature = "client", feature = "server"))]
    pub(super) fn new(role: Role, limits: Limits) -> Self {
        let validator = Validator::new();

        Self {
            validator,
            payload_processed: 0,
            fragmented_message_opcode: OpCode::Continuation,
            limits,
            role,
        }
    }
}
/// `Encoder` implementation that is not expected to be called.
///
/// NOTE(review): encoding presumably happens elsewhere and this impl only exists
/// to satisfy a trait bound — confirm. If it is ever reached (for example via a
/// `Sink` on a `Framed` built from this codec), it panics with a clear message
/// instead of invoking undefined behaviour through `unreachable_unchecked`.
impl Encoder<Frame> for WebsocketProtocol {
    type Error = Error;

    fn encode(&mut self, _item: Frame, _dst: &mut BytesMut) -> Result<(), Self::Error> {
        unreachable!("WebsocketProtocol is decode-only and cannot encode frames")
    }
}
/// Early-returns `Ok(None)` from the enclosing function when `$buf` holds fewer
/// than `$space` bytes.
///
/// Before returning, it reserves exactly enough capacity for `$buf` to hold
/// `$space` bytes in total. `reserve(n)` guarantees room for `len + n` bytes, so
/// the amount to reserve is the number of *missing* bytes (`$space - len`), not
/// `$space - capacity`. `$space` is evaluated exactly once.
macro_rules! ensure_buffer_has_space {
    ($buf:expr, $space:expr) => {{
        let needed: usize = $space;
        if $buf.len() < needed {
            let missing = needed - $buf.len();
            $buf.reserve(missing);
            return Ok(None);
        }
    }};
}
impl Decoder for WebsocketProtocol {
type Error = Error;
type Item = Frame;
#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
ensure_buffer_has_space!(src, 2);
let fin_and_rsv = unsafe { src.get_unchecked(0) };
let payload_len_1 = unsafe { src.get_unchecked(1) };
let fin = fin_and_rsv >> 7 != 0;
let rsv = fin_and_rsv & 0x70;
if rsv != 0 {
return Err(Error::Protocol(ProtocolError::InvalidRsv));
}
let opcode = OpCode::try_from(fin_and_rsv & 0xF)?;
if opcode.is_control() {
if !fin {
return Err(Error::Protocol(ProtocolError::FragmentedControlFrame));
}
} else if self.fragmented_message_opcode == OpCode::Continuation {
if opcode == OpCode::Continuation {
return Err(Error::Protocol(ProtocolError::InvalidOpcode));
}
} else if opcode != OpCode::Continuation {
return Err(Error::Protocol(ProtocolError::InvalidOpcode));
}
let mask = payload_len_1 >> 7 != 0;
if mask && self.role == Role::Client {
return Err(Error::Protocol(ProtocolError::UnexpectedMaskedFrame));
} else if !mask && self.role == Role::Server {
return Err(Error::Protocol(ProtocolError::UnexpectedUnmaskedFrame));
}
let mut payload_length = (payload_len_1 & 127) as usize;
let mut offset = 2;
if opcode == OpCode::Close && payload_length == 1 {
return Err(Error::Protocol(ProtocolError::InvalidPayloadLength));
} else if payload_length > 125 {
if opcode.is_control() {
return Err(Error::Protocol(ProtocolError::InvalidPayloadLength));
}
if payload_length == 126 {
ensure_buffer_has_space!(src, offset + 2);
payload_length = u16::from_be_bytes(unsafe {
src.get_unchecked(2..4).try_into().unwrap_unchecked()
}) as usize;
if payload_length <= 125 {
return Err(Error::Protocol(ProtocolError::InvalidPayloadLength));
}
offset = 4;
} else if payload_length == 127 {
ensure_buffer_has_space!(src, offset + 8);
payload_length = u64::from_be_bytes(unsafe {
src.get_unchecked(2..10).try_into().unwrap_unchecked()
}) as usize;
if u16::try_from(payload_length).is_ok() {
return Err(Error::Protocol(ProtocolError::InvalidPayloadLength));
}
offset = 10;
} else {
unsafe { unreachable_unchecked() }
}
}
if payload_length > self.limits.max_payload_len.unwrap_or(usize::MAX) {
return Err(Error::PayloadTooLong {
len: payload_length,
max_len: self.limits.max_payload_len.unwrap_or(usize::MAX),
});
}
if mask {
ensure_buffer_has_space!(src, offset + 4);
offset += 4;
}
if payload_length != 0 {
let payload_available = (src.len() - offset).min(payload_length);
if payload_length != payload_available {
if opcode == OpCode::Text
|| (opcode == OpCode::Continuation
&& self.fragmented_message_opcode == OpCode::Text)
{
if mask {
let (masking_key, payload_masked) = unsafe {
let masking_key_ptr =
src.get_unchecked(offset - 4..offset) as *const [u8];
let payload_masked_ptr = src.get_unchecked_mut(
offset + self.payload_processed..offset + payload_available,
) as *mut [u8];
(&*masking_key_ptr, &mut *payload_masked_ptr)
};
mask::frame(masking_key, payload_masked, self.payload_processed & 3);
}
self.validator.feed(
unsafe {
src.get_unchecked(
offset + self.payload_processed..offset + payload_available,
)
},
false,
)?;
self.payload_processed = payload_available;
}
src.reserve((payload_length - payload_available).saturating_sub(src.capacity()));
return Ok(None);
}
if mask {
let (masking_key, payload_masked) = unsafe {
let masking_key_ptr = src.get_unchecked(offset - 4..offset) as *const [u8];
let payload_masked_ptr = src
.get_unchecked_mut(offset + self.payload_processed..offset + payload_length)
as *mut [u8];
(&*masking_key_ptr, &mut *payload_masked_ptr)
};
mask::frame(masking_key, payload_masked, self.payload_processed & 3);
}
if opcode == OpCode::Text
|| (opcode == OpCode::Continuation
&& self.fragmented_message_opcode == OpCode::Text)
{
self.validator.feed(
unsafe {
src.get_unchecked(offset + self.payload_processed..offset + payload_length)
},
fin,
)?;
} else if opcode == OpCode::Close {
let code = CloseCode::try_from(u16::from_be_bytes(unsafe {
src.get_unchecked(offset..offset + 2)
.try_into()
.unwrap_unchecked()
}))?;
if !code.is_sendable() {
return Err(Error::Protocol(ProtocolError::InvalidCloseCode));
}
let _reason = utf8::parse_str(unsafe {
src.get_unchecked(offset + 2..offset + payload_length)
})?;
}
}
src.advance(offset);
let payload = src.split_to(payload_length).into();
if fin && !opcode.is_control() {
self.fragmented_message_opcode = OpCode::Continuation;
} else if opcode != OpCode::Continuation && !opcode.is_control() {
self.fragmented_message_opcode = opcode;
}
self.payload_processed = 0;
Ok(Some(Frame {
opcode,
payload,
is_final: fin,
}))
}
}