use crate::error::{Error, ErrorCode, Result};
use crate::frame::{
CONNECTION_STREAM_ID, ContinuationFrame, DataFrame, FRAME_HEADER_SIZE, Frame, FrameFlags,
FrameHeader, FrameType, GoawayFrame, HeadersFrame, PingFrame, PriorityFields, PriorityFrame,
PriorityUpdateFrame, RstStreamFrame, SettingsFrame, WindowUpdateFrame,
};
use crate::settings::Setting;
/// Incremental decoder that turns a byte stream into HTTP/2 [`Frame`]s.
///
/// Bytes are appended with [`FrameDecoder::feed`] and frames are pulled out one at a
/// time with [`FrameDecoder::decode`]. A frame header is parsed as soon as its bytes are
/// available and is remembered until the whole payload has been buffered.
#[derive(Debug)]
pub struct FrameDecoder {
    /// Largest payload length accepted; longer frames are rejected with a frame size error.
    max_frame_size: u32,
    /// Bytes received via `feed` that have not been consumed by `decode` yet.
    buf: Vec<u8>,
    /// Header of a frame whose payload is still incomplete, if any.
    current_header: Option<FrameHeader>,
    /// Stream ID of the most recent frame whose payload was fully buffered.
    last_decoded_stream_id: Option<u32>,
}

impl FrameDecoder {
    /// Creates an empty decoder that rejects payloads longer than `max_frame_size`.
    #[must_use]
    pub fn new(max_frame_size: u32) -> Self {
        Self {
            max_frame_size,
            buf: Vec::new(),
            current_header: None,
            last_decoded_stream_id: None,
        }
    }

    /// Changes the payload size limit used for frame headers parsed from now on.
    ///
    /// A header that has already been parsed (payload still pending) is not re-checked.
    pub fn set_max_frame_size(&mut self, max_frame_size: u32) {
        self.max_frame_size = max_frame_size;
    }

    /// Appends raw bytes to the internal buffer; call [`FrameDecoder::decode`] to
    /// extract frames.
    pub fn feed(&mut self, data: &[u8]) {
        self.buf.extend_from_slice(data);
    }

    /// Attempts to decode the next complete frame from the buffered bytes.
    ///
    /// Returns `Ok(None)` when more bytes are needed. A parsed header is kept across
    /// calls, so feeding the remaining payload later resumes where this call stopped.
    ///
    /// # Errors
    ///
    /// * Errors from [`decode_header`].
    /// * A frame size error when the header's length exceeds the configured maximum. In
    ///   that case the header bytes are left in the buffer, so later calls report the same
    ///   error until [`FrameDecoder::clear`] is called.
    /// * Any error from the per-frame-type decoder. The frame's bytes are consumed and
    ///   [`FrameDecoder::last_decoded_stream_id`] is still updated, so the caller can see
    ///   which stream the failing frame was on.
    pub fn decode(&mut self) -> Result<Option<Frame>> {
        let header = match self.current_header.take() {
            Some(header) => header,
            None => {
                if self.buf.len() < FRAME_HEADER_SIZE {
                    return Ok(None);
                }
                let header = decode_header(&self.buf[..FRAME_HEADER_SIZE])?;
                if header.length > self.max_frame_size {
                    return Err(Error::frame_size_error(format!(
                        "frame size {} exceeds max frame size {}",
                        header.length, self.max_frame_size
                    )));
                }
                self.buf.drain(..FRAME_HEADER_SIZE);
                header
            }
        };
        let payload_len = header.length as usize;
        if self.buf.len() < payload_len {
            // Remember the header so the next call only waits for the payload.
            self.current_header = Some(header);
            return Ok(None);
        }
        self.last_decoded_stream_id = Some(header.stream_id);
        // Decode straight from the buffer. The per-type decoders copy what they keep,
        // so a temporary payload Vec would just be a second copy.
        let frame = decode_frame(header, &self.buf[..payload_len]);
        self.buf.drain(..payload_len);
        frame.map(Some)
    }

    /// Number of bytes buffered but not yet consumed (excludes an already-parsed header).
    #[must_use]
    pub fn buffered_len(&self) -> usize {
        self.buf.len()
    }

    /// Stream ID of the last frame whose payload was fully buffered, whether or not it
    /// decoded successfully.
    #[must_use]
    pub fn last_decoded_stream_id(&self) -> Option<u32> {
        self.last_decoded_stream_id
    }

    /// Discards all buffered bytes and any partially received frame.
    ///
    /// `last_decoded_stream_id` is intentionally left untouched.
    pub fn clear(&mut self) {
        self.buf.clear();
        self.current_header = None;
    }
}

impl Default for FrameDecoder {
    /// Creates a decoder limited to `crate::settings::DEFAULT_MAX_FRAME_SIZE`.
    fn default() -> Self {
        Self::new(crate::settings::DEFAULT_MAX_FRAME_SIZE)
    }
}
/// Parses the fixed-size HTTP/2 frame header at the start of `buf`.
///
/// Only the first `FRAME_HEADER_SIZE` bytes are read; trailing bytes are ignored.
///
/// # Errors
///
/// Returns `Error::incomplete()` when `buf` is shorter than `FRAME_HEADER_SIZE`.
pub fn decode_header(buf: &[u8]) -> Result<FrameHeader> {
    if buf.len() < FRAME_HEADER_SIZE {
        return Err(Error::incomplete());
    }
    // 24-bit big-endian length, widened to u32 by a zero high byte.
    let length = u32::from_be_bytes([0, buf[0], buf[1], buf[2]]);
    // 31-bit stream identifier; the reserved top bit is masked off.
    let stream_id = u32::from_be_bytes([buf[5], buf[6], buf[7], buf[8]]) & 0x7fff_ffff;
    Ok(FrameHeader {
        length,
        frame_type: buf[3],
        flags: FrameFlags::from_bits(buf[4]),
        stream_id,
    })
}
fn decode_frame(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
match FrameType::from_u8(header.frame_type) {
Some(FrameType::Data) => decode_data(header, payload),
Some(FrameType::Headers) => decode_headers(header, payload),
Some(FrameType::Priority) => decode_priority(header, payload),
Some(FrameType::RstStream) => decode_rst_stream(header, payload),
Some(FrameType::Settings) => decode_settings(header, payload),
Some(FrameType::PushPromise) => decode_push_promise(header, payload),
Some(FrameType::Ping) => decode_ping(header, payload),
Some(FrameType::Goaway) => decode_goaway(header, payload),
Some(FrameType::WindowUpdate) => decode_window_update(header, payload),
Some(FrameType::Continuation) => decode_continuation(header, payload),
Some(FrameType::PriorityUpdate) => decode_priority_update(header, payload),
None => Ok(Frame::Unknown {
header,
payload: payload.to_vec(),
}),
}
}
/// Decodes a DATA frame, stripping padding when the PADDED flag is set.
///
/// # Errors
///
/// * Protocol error for stream ID 0, or when the padding is not shorter than the payload.
/// * Frame size error when PADDED is set but the pad-length byte is missing.
fn decode_data(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
    let stream_id = header.stream_id;
    if stream_id == CONNECTION_STREAM_ID {
        return Err(Error::protocol_error("DATA frame with stream ID 0"));
    }
    let end_stream = header.flags.is_end_stream();

    if !header.flags.is_padded() {
        return Ok(Frame::Data(DataFrame {
            stream_id,
            end_stream,
            data: payload.to_vec(),
            pad_length: None,
        }));
    }

    let Some((&pad_byte, body)) = payload.split_first() else {
        return Err(Error::frame_size_error(
            "PADDED flag set but no padding length",
        ));
    };
    let pad_len = usize::from(pad_byte);
    if pad_len > body.len() {
        return Err(Error::protocol_error(
            "padding length exceeds frame payload",
        ));
    }
    Ok(Frame::Data(DataFrame {
        stream_id,
        end_stream,
        data: body[..body.len() - pad_len].to_vec(),
        pad_length: Some(pad_byte),
    }))
}
/// Decodes a HEADERS frame: optional pad length, optional priority fields,
/// the field block fragment, then optional padding.
///
/// The mandatory leading fields (pad length, priority) are consumed first. Padding is
/// then checked against what is left for the fragment, so padding that overlaps the
/// priority fields is reported as a protocol error rather than a frame size error.
///
/// # Errors
///
/// * Protocol error for stream ID 0.
/// * Frame size error when the PADDED or PRIORITY flags announce fields the payload is
///   too short to contain.
/// * Protocol error when the padding exceeds the bytes left for the fragment.
fn decode_headers(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
    if header.stream_id == CONNECTION_STREAM_ID {
        return Err(Error::protocol_error("HEADERS frame with stream ID 0"));
    }
    let end_stream = header.flags.is_end_stream();
    let end_headers = header.flags.is_end_headers();

    // `rest` shrinks as each optional leading field is consumed.
    let mut rest = payload;

    let pad_length = if header.flags.is_padded() {
        let Some((&pad_byte, tail)) = rest.split_first() else {
            return Err(Error::frame_size_error(
                "PADDED flag set but no padding length",
            ));
        };
        rest = tail;
        Some(pad_byte)
    } else {
        None
    };

    let priority_fields = if header.flags.is_priority() {
        if rest.len() < 5 {
            return Err(Error::frame_size_error(
                "PRIORITY flag set but insufficient data for priority fields",
            ));
        }
        let (fields, tail) = rest.split_at(5);
        rest = tail;
        Some(PriorityFields {
            exclusive: (fields[0] & 0x80) != 0,
            // 31-bit stream dependency; the top bit is the exclusive flag.
            stream_dependency: u32::from_be_bytes([
                fields[0] & 0x7f,
                fields[1],
                fields[2],
                fields[3],
            ]),
            weight: fields[4],
        })
    } else {
        None
    };

    // Padding sits at the end and must fit in what remains after the leading fields.
    let pad_len = pad_length.map_or(0, usize::from);
    let Some(fragment_len) = rest.len().checked_sub(pad_len) else {
        return Err(Error::protocol_error(
            "padding length exceeds frame payload",
        ));
    };

    Ok(Frame::Headers(HeadersFrame {
        stream_id: header.stream_id,
        end_stream,
        end_headers,
        priority_fields,
        header_block_fragment: rest[..fragment_len].to_vec(),
        pad_length,
    }))
}
fn decode_priority(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
if header.stream_id == CONNECTION_STREAM_ID {
return Err(Error::protocol_error("PRIORITY frame with stream ID 0"));
}
if payload.len() != 5 {
return Err(Error::stream_error(
crate::error::ErrorCode::FrameSizeError,
"PRIORITY frame must be 5 bytes",
));
}
let exclusive = (payload[0] & 0x80) != 0;
let stream_dependency = ((u32::from(payload[0]) & 0x7f) << 24)
| (u32::from(payload[1]) << 16)
| (u32::from(payload[2]) << 8)
| u32::from(payload[3]);
let weight = payload[4];
Ok(Frame::Priority(PriorityFrame {
stream_id: header.stream_id,
exclusive,
stream_dependency,
weight,
}))
}
/// Decodes an RST_STREAM frame carrying a single 32-bit big-endian error code.
///
/// # Errors
///
/// * Protocol error for stream ID 0.
/// * Frame size error when the payload is not 4 bytes.
fn decode_rst_stream(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
    let stream_id = header.stream_id;
    if stream_id == CONNECTION_STREAM_ID {
        return Err(Error::protocol_error("RST_STREAM frame with stream ID 0"));
    }
    let &[c0, c1, c2, c3] = payload else {
        return Err(Error::frame_size_error("RST_STREAM frame must be 4 bytes"));
    };
    Ok(Frame::RstStream(RstStreamFrame {
        stream_id,
        error_code: u32::from_be_bytes([c0, c1, c2, c3]),
    }))
}
fn decode_settings(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
if header.stream_id != CONNECTION_STREAM_ID {
return Err(Error::protocol_error(
"SETTINGS frame with non-zero stream ID",
));
}
let ack = header.flags.is_ack();
if ack {
if !payload.is_empty() {
return Err(Error::frame_size_error(
"SETTINGS ACK frame with non-empty payload",
));
}
return Ok(Frame::Settings(SettingsFrame::ack()));
}
if !payload.len().is_multiple_of(6) {
return Err(Error::frame_size_error(
"SETTINGS frame payload must be a multiple of 6 bytes",
));
}
let mut settings = Vec::with_capacity(payload.len() / 6);
for chunk in payload.chunks(6) {
let id = u16::from_be_bytes([chunk[0], chunk[1]]);
let value = u32::from_be_bytes([chunk[2], chunk[3], chunk[4], chunk[5]]);
settings.push(Setting::new(id, value));
}
Ok(Frame::Settings(SettingsFrame {
ack: false,
settings,
}))
}
/// Decodes a PING frame carrying exactly 8 bytes of opaque data.
///
/// # Errors
///
/// * Protocol error when the stream ID is non-zero.
/// * Frame size error when the payload is not 8 bytes.
fn decode_ping(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
    if header.stream_id != CONNECTION_STREAM_ID {
        return Err(Error::protocol_error("PING frame with non-zero stream ID"));
    }
    if payload.len() != 8 {
        return Err(Error::frame_size_error("PING frame must be 8 bytes"));
    }
    Ok(Frame::Ping(PingFrame {
        ack: header.flags.is_ack(),
        opaque_data: std::array::from_fn(|i| payload[i]),
    }))
}
/// Decodes a GOAWAY frame.
///
/// The first 8 bytes hold the last stream ID (31 bits, reserved bit masked off) and the
/// error code; any remaining bytes are returned verbatim as debug data.
///
/// # Errors
///
/// * Protocol error when the stream ID is non-zero.
/// * Frame size error when the payload is shorter than 8 bytes.
fn decode_goaway(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
    if header.stream_id != CONNECTION_STREAM_ID {
        return Err(Error::protocol_error(
            "GOAWAY frame with non-zero stream ID",
        ));
    }
    if payload.len() < 8 {
        return Err(Error::frame_size_error(
            "GOAWAY frame must be at least 8 bytes",
        ));
    }
    let (fixed, trailing) = payload.split_at(8);
    let last_stream_id =
        u32::from_be_bytes([fixed[0], fixed[1], fixed[2], fixed[3]]) & 0x7fff_ffff;
    let error_code = u32::from_be_bytes([fixed[4], fixed[5], fixed[6], fixed[7]]);
    Ok(Frame::Goaway(GoawayFrame {
        last_stream_id,
        error_code,
        debug_data: trailing.to_vec(),
    }))
}
fn decode_window_update(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
if payload.len() != 4 {
return Err(Error::frame_size_error(
"WINDOW_UPDATE frame must be 4 bytes",
));
}
let window_size_increment = ((u32::from(payload[0]) & 0x7f) << 24)
| (u32::from(payload[1]) << 16)
| (u32::from(payload[2]) << 8)
| u32::from(payload[3]);
if window_size_increment == 0 {
if header.stream_id == CONNECTION_STREAM_ID {
return Err(Error::connection_error(
ErrorCode::ProtocolError,
"WINDOW_UPDATE with zero increment on connection",
));
}
return Err(Error::stream_error(
ErrorCode::ProtocolError,
"WINDOW_UPDATE with zero increment on stream",
));
}
Ok(Frame::WindowUpdate(WindowUpdateFrame {
stream_id: header.stream_id,
window_size_increment,
}))
}
/// Decodes a CONTINUATION frame; the whole payload is a field block fragment.
///
/// # Errors
///
/// Protocol error for stream ID 0.
fn decode_continuation(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
    let stream_id = header.stream_id;
    if stream_id == CONNECTION_STREAM_ID {
        return Err(Error::protocol_error("CONTINUATION frame with stream ID 0"));
    }
    Ok(Frame::Continuation(ContinuationFrame {
        stream_id,
        end_headers: header.flags.is_end_headers(),
        header_block_fragment: payload.to_vec(),
    }))
}
/// Decodes a PUSH_PROMISE frame, keeping only the stream ID it arrived on.
///
/// The payload (promised stream ID and field block fragment) is discarded.
/// NOTE(review): this presumably relies on the caller rejecting server push; if push
/// were ever accepted, the dropped fragment would need to be decoded. Confirm with the
/// connection logic.
///
/// # Errors
///
/// Protocol error for stream ID 0.
fn decode_push_promise(header: FrameHeader, _payload: &[u8]) -> Result<Frame> {
    let stream_id = header.stream_id;
    if stream_id == CONNECTION_STREAM_ID {
        return Err(Error::protocol_error("PUSH_PROMISE frame with stream ID 0"));
    }
    Ok(Frame::PushPromise { stream_id })
}
/// Decodes a PRIORITY_UPDATE frame (RFC 9218).
///
/// The first 4 bytes are the prioritized element (stream) ID, with the reserved top bit
/// masked off. The remaining bytes are the Priority Field Value, returned verbatim.
///
/// # Errors
///
/// * Protocol error when the frame's own stream ID is non-zero.
/// * Frame size error when the payload is shorter than 4 bytes.
/// * Protocol error when the prioritized stream ID is 0. RFC 9218 §7.1 requires this to
///   be a connection error, because stream 0 cannot be reprioritized.
fn decode_priority_update(header: FrameHeader, payload: &[u8]) -> Result<Frame> {
    if header.stream_id != CONNECTION_STREAM_ID {
        return Err(Error::protocol_error(
            "PRIORITY_UPDATE frame with non-zero stream ID",
        ));
    }
    if payload.len() < 4 {
        return Err(Error::frame_size_error(
            "PRIORITY_UPDATE frame must be at least 4 bytes",
        ));
    }
    let (id_bytes, field_value) = payload.split_at(4);
    let prioritized_element_id =
        u32::from_be_bytes([id_bytes[0] & 0x7f, id_bytes[1], id_bytes[2], id_bytes[3]]);
    if prioritized_element_id == CONNECTION_STREAM_ID {
        return Err(Error::protocol_error(
            "PRIORITY_UPDATE frame with prioritized stream ID 0",
        ));
    }
    Ok(Frame::PriorityUpdate(PriorityUpdateFrame {
        prioritized_element_id,
        priority_field_value: field_value.to_vec(),
    }))
}