use crate::buffer::{BufferPoolRef, OwnedReadBuffer};
use crate::util::into_array;
use crate::ws::{Error, WebsocketFrame, protocol};
use std::io;
use std::io::Read;
/// Incremental WebSocket frame decoder.
///
/// Bytes are pulled into `buffer` by `Decoder::read` and parsed by `Decoder::decode_next`.
/// Progress through a partially received frame is kept in `decode_state` together with the
/// header fields below, so decoding can resume after more bytes arrive.
#[derive(Debug)]
pub struct Decoder {
    /// Read buffer acquired from the pool passed to `Decoder::new`.
    buffer: OwnedReadBuffer<4096>,
    /// Which part of the current frame `decode_next` expects next.
    decode_state: DecodeState,
    /// FIN bit of the frame currently being decoded.
    fin: bool,
    /// Payload length of the current frame, taken from the 7-bit, 16-bit or 64-bit length field.
    payload_length: usize,
    /// Op code of the frame currently being decoded.
    op_code: u8,
    /// Set when `decode_next` ran out of buffered bytes; tells `read` to pull from the stream.
    needs_more_data: bool,
}
/// Position of the decoder within the frame it is currently decoding.
///
/// The states follow the wire order of a frame: the first header byte, the byte holding the
/// mask bit and 7-bit length, an optional 2- or 8-byte extended length, then the payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DecodeState {
    /// Waiting for the byte holding FIN, RSV1-3 and the op code.
    ReadingHeader,
    /// Waiting for the byte holding the mask bit and the 7-bit payload length.
    ReadingPayloadLength,
    /// Waiting for the 16-bit big-endian extended payload length (length field was 126).
    ReadingExtendedPayloadLength2,
    /// Waiting for the 64-bit big-endian extended payload length (length field was 127).
    ReadingExtendedPayloadLength8,
    /// Waiting until the whole payload is buffered.
    ReadingPayload,
}
impl Decoder {
pub fn new(pool: &mut BufferPoolRef) -> Self {
Self {
buffer: pool.acquire(),
decode_state: DecodeState::ReadingHeader,
fin: false,
op_code: 0,
payload_length: 0,
needs_more_data: true,
}
}
#[inline]
pub fn read<S: Read>(&mut self, stream: &mut S) -> io::Result<()> {
if self.needs_more_data {
self.buffer.read_all_from(stream)?;
self.needs_more_data = false;
}
Ok(())
}
#[inline]
pub fn decode_next(&mut self) -> Result<Option<WebsocketFrame>, Error> {
loop {
let available = self.buffer.available();
match self.decode_state {
DecodeState::ReadingHeader => {
if available > 0 {
let b = unsafe { self.buffer.consume_next_byte_unchecked() };
let fin = ((b & protocol::FIN_MASK) >> 7) == 1;
let rsv1 = (b & protocol::RSV1_MASK) >> 6;
let rsv2 = (b & protocol::RSV2_MASK) >> 5;
let rsv3 = (b & protocol::RSV3_MASK) >> 4;
if rsv1 + rsv2 + rsv3 != 0 {
return Err(Error::Protocol("non zero RSV value received".to_string()));
}
self.fin = fin;
let op_code = b & protocol::OP_CODE_MASK;
self.op_code = op_code;
self.decode_state = DecodeState::ReadingPayloadLength
} else {
break;
}
}
DecodeState::ReadingPayloadLength => {
if available > 0 {
let b = unsafe { self.buffer.consume_next_byte_unchecked() };
let mask = (b & protocol::MASK_MASK) >> 7;
if mask == 1 {
return Err(Error::Protocol("masking bit set on the server frame".to_string()));
}
let payload_length = b & protocol::PAYLOAD_LENGTH_MASK;
self.payload_length = payload_length as usize;
match payload_length {
0..=125 => self.decode_state = DecodeState::ReadingPayload,
126 => self.decode_state = DecodeState::ReadingExtendedPayloadLength2,
127 => self.decode_state = DecodeState::ReadingExtendedPayloadLength8,
_ => unsafe { std::hint::unreachable_unchecked() },
}
} else {
break;
}
}
DecodeState::ReadingExtendedPayloadLength2 => {
if available >= 2 {
let bytes = unsafe { self.buffer.consume_next_unchecked(2) };
let payload_length = u16::from_be_bytes(unsafe { into_array(bytes) });
self.payload_length = payload_length as usize;
self.decode_state = DecodeState::ReadingPayload;
} else {
break;
}
}
DecodeState::ReadingExtendedPayloadLength8 => {
if available >= 8 {
let bytes = unsafe { self.buffer.consume_next_unchecked(8) };
let payload_length = u64::from_be_bytes(unsafe { into_array(bytes) });
self.payload_length = payload_length as usize;
self.decode_state = DecodeState::ReadingPayload;
} else {
break;
}
}
DecodeState::ReadingPayload => {
let payload_length = self.payload_length;
if available >= payload_length {
let payload = unsafe { self.buffer.consume_next_unchecked(payload_length) };
let frame = match self.op_code {
protocol::op::TEXT_FRAME => WebsocketFrame::Text(self.fin, payload),
protocol::op::BINARY_FRAME => WebsocketFrame::Binary(self.fin, payload),
protocol::op::CONTINUATION_FRAME => WebsocketFrame::Continuation(self.fin, payload),
protocol::op::PING => WebsocketFrame::Ping(payload),
protocol::op::PONG => WebsocketFrame::Pong(payload),
protocol::op::CONNECTION_CLOSE => WebsocketFrame::Close(payload),
_ => return Err(Error::Protocol(format!("unknown op_code: {}", self.op_code))),
};
self.decode_state = DecodeState::ReadingHeader;
return Ok(Some(frame));
} else {
break;
}
}
}
}
self.needs_more_data = true;
Ok(None)
}
}