use super::{H2ErrorCode, H2Settings};
pub(crate) mod continuation;
pub(crate) mod data;
pub(crate) mod goaway;
pub(crate) mod headers;
pub(crate) mod ping;
pub(crate) mod priority;
pub(crate) mod rst_stream;
pub(crate) mod settings;
pub(crate) mod window_update;
/// Size in bytes of the fixed frame header that precedes every payload:
/// a 24-bit payload length, an 8-bit type, an 8-bit flags field, and a
/// 32-bit word holding a reserved bit plus the 31-bit stream identifier.
pub(crate) const FRAME_HEADER_LEN: usize = 3 + 1 + 1 + 4;
/// Frame types this module recognizes, each tagged with its wire type code.
///
/// Codes outside `0x0..=0x9` have no variant; `TryFrom<u8>` hands such codes
/// back unchanged so the caller can treat the frame as unknown.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub(crate) enum FrameType {
    /// DATA (RFC 9113 §6.1).
    Data = 0x0,
    /// HEADERS (RFC 9113 §6.2).
    Headers = 0x1,
    /// PRIORITY (RFC 9113 §6.3).
    Priority = 0x2,
    /// RST_STREAM (RFC 9113 §6.4).
    RstStream = 0x3,
    /// SETTINGS (RFC 9113 §6.5).
    Settings = 0x4,
    /// PUSH_PROMISE (RFC 9113 §6.6).
    PushPromise = 0x5,
    /// PING (RFC 9113 §6.7).
    Ping = 0x6,
    /// GOAWAY (RFC 9113 §6.8).
    Goaway = 0x7,
    /// WINDOW_UPDATE (RFC 9113 §6.9).
    WindowUpdate = 0x8,
    /// CONTINUATION (RFC 9113 §6.10).
    Continuation = 0x9,
}

impl TryFrom<u8> for FrameType {
    type Error = u8;

    /// Maps a wire type code to its [`FrameType`], returning the raw code as
    /// the error when it does not name a known type.
    fn try_from(value: u8) -> Result<Self, u8> {
        // The known codes are contiguous starting at 0x0, so each code is
        // also its index in this table.
        const BY_CODE: [FrameType; 10] = [
            FrameType::Data,
            FrameType::Headers,
            FrameType::Priority,
            FrameType::RstStream,
            FrameType::Settings,
            FrameType::PushPromise,
            FrameType::Ping,
            FrameType::Goaway,
            FrameType::WindowUpdate,
            FrameType::Continuation,
        ];
        BY_CODE.get(usize::from(value)).copied().ok_or(value)
    }
}
/// END_STREAM flag (bit 0), defined for DATA and HEADERS frames.
pub(crate) const FLAG_END_STREAM: u8 = 0b0000_0001;
/// ACK flag (bit 0), defined for SETTINGS and PING frames. It shares its bit
/// with `FLAG_END_STREAM`; the frame type decides which meaning applies.
pub(crate) const FLAG_ACK: u8 = 0b0000_0001;
/// END_HEADERS flag (bit 2), defined for HEADERS, PUSH_PROMISE and CONTINUATION.
pub(crate) const FLAG_END_HEADERS: u8 = 0b0000_0100;
/// PADDED flag (bit 3), defined for DATA, HEADERS and PUSH_PROMISE.
pub(crate) const FLAG_PADDED: u8 = 0b0000_1000;
/// PRIORITY flag (bit 5), defined for HEADERS frames.
pub(crate) const FLAG_PRIORITY: u8 = 0b0010_0000;
/// The fixed `FRAME_HEADER_LEN`-byte header that precedes every frame payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct FrameHeader {
    /// Payload length in bytes; a 24-bit field on the wire.
    pub(crate) length: u32,
    /// Raw type code; it may not correspond to any known `FrameType`.
    pub(crate) frame_type: u8,
    /// Raw flag bits; their meaning depends on the frame type.
    pub(crate) flags: u8,
    /// Stream identifier. `decode` clears the reserved high bit and `encode`
    /// masks it off before writing.
    pub(crate) stream_id: u32,
}
impl FrameHeader {
    /// Parses a frame header from the start of `input`.
    ///
    /// Returns `None` when `input` holds fewer than `FRAME_HEADER_LEN` bytes.
    /// Any bytes after the header are ignored.
    pub(crate) fn decode(input: &[u8]) -> Option<Self> {
        if let [len_hi, len_mid, len_lo, frame_type, flags, id0, id1, id2, id3, ..] = *input {
            Some(Self {
                length: u32::from_be_bytes([0, len_hi, len_mid, len_lo]),
                frame_type,
                flags,
                // The top bit of the stream-id word is reserved; drop it.
                stream_id: u32::from_be_bytes([id0, id1, id2, id3]) & 0x7FFF_FFFF,
            })
        } else {
            None
        }
    }

    /// Serializes this header into the first `FRAME_HEADER_LEN` bytes of `buf`.
    ///
    /// Debug builds assert three things: `buf` is large enough, `length` fits
    /// 24 bits, and `stream_id` fits 31 bits. Release builds behave as follows:
    /// - Only the low 24 bits of `length` are written.
    /// - The reserved bit of `stream_id` is cleared.
    /// - A too-short `buf` panics on slice indexing.
    pub(crate) fn encode(&self, buf: &mut [u8]) {
        debug_assert!(
            buf.len() >= FRAME_HEADER_LEN,
            "frame header buffer too small"
        );
        debug_assert!(self.length < (1 << 24), "payload length exceeds 24 bits");
        debug_assert!(self.stream_id < (1 << 31), "stream id exceeds 31 bits");
        // Big-endian length with its most significant byte dropped (24-bit field).
        let [_, len_hi, len_mid, len_lo] = self.length.to_be_bytes();
        let [id0, id1, id2, id3] = (self.stream_id & 0x7FFF_FFFF).to_be_bytes();
        let wire = [
            len_hi,
            len_mid,
            len_lo,
            self.frame_type,
            self.flags,
            id0,
            id1,
            id2,
            id3,
        ];
        buf[..FRAME_HEADER_LEN].copy_from_slice(&wire);
    }
}
/// Reasons `Frame::decode` could not produce a frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FrameDecodeError {
    /// More input bytes are needed before the frame can be decoded.
    Incomplete,
    /// The input is invalid; carries the HTTP/2 error code describing why.
    Error(H2ErrorCode),
}
impl From<H2ErrorCode> for FrameDecodeError {
    /// Wraps an HTTP/2 error code so `?` can lift it into a decode error.
    fn from(code: H2ErrorCode) -> Self {
        FrameDecodeError::Error(code)
    }
}
/// A decoded HTTP/2 frame, as returned by `Frame::decode`.
///
/// How each variant is built:
/// - PRIORITY, RST_STREAM, SETTINGS, PING, GOAWAY and WINDOW_UPDATE: from the
///   complete payload.
/// - DATA and HEADERS: from the frame header plus a payload prefix.
/// - CONTINUATION, PUSH_PROMISE and unknown types: from the frame header alone.
///
/// For the last two groups, the remaining payload bytes are not consumed by
/// `Frame::decode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Frame {
/// DATA frame, built by `data::decode_prefix`. NOTE(review): `data_length` and
/// `padding_length` presumably describe the unconsumed payload; confirm in `data`.
Data {
stream_id: u32,
end_stream: bool,
data_length: u32,
padding_length: u8,
},
/// HEADERS frame, built by `headers::decode_prefix`. `priority` holds the
/// parsed priority fields, if any. NOTE(review): `header_block_length`
/// presumably excludes padding and priority bytes; confirm in `headers`.
Headers {
stream_id: u32,
end_stream: bool,
end_headers: bool,
priority: Option<PriorityInfo>,
header_block_length: u32,
padding_length: u8,
},
/// PRIORITY frame; the full payload is decoded by `priority::decode`.
Priority {
stream_id: u32,
priority: PriorityInfo,
},
/// RST_STREAM frame; the full payload is decoded by `rst_stream::decode`.
RstStream {
stream_id: u32,
error_code: H2ErrorCode,
},
/// SETTINGS frame carrying parameters, decoded by `settings::decode`.
Settings(H2Settings),
/// SETTINGS acknowledgement, produced by `settings::decode`.
/// Presumably this is a SETTINGS frame with `FLAG_ACK` set; TODO confirm.
SettingsAck,
/// PUSH_PROMISE frame. Only the frame header is decoded. `length` is the raw
/// payload length from that header, and the payload is left unconsumed.
PushPromise {
stream_id: u32,
length: u32,
},
/// PING frame with its 8 opaque bytes. `ack` presumably reflects `FLAG_ACK`
/// (decoded in `ping`, not visible here).
Ping {
opaque_data: [u8; 8],
ack: bool,
},
/// GOAWAY frame. NOTE(review): `debug_data_length` presumably counts the
/// trailing debug bytes; confirm in `goaway`.
Goaway {
last_stream_id: u32,
error_code: H2ErrorCode,
debug_data_length: u32,
},
/// WINDOW_UPDATE frame; the full payload is decoded by `window_update::decode`.
WindowUpdate {
stream_id: u32,
increment: u32,
},
/// CONTINUATION frame, built from the frame header alone by
/// `continuation::decode`. Presumably `header_block_length` is the header's
/// payload length; the payload itself is left unconsumed.
Continuation {
stream_id: u32,
end_headers: bool,
header_block_length: u32,
},
/// A frame whose type code is not a known `FrameType`. All fields are copied
/// from the frame header, and the payload is left unconsumed.
Unknown {
stream_id: u32,
frame_type: u8,
flags: u8,
length: u32,
},
}
/// Stream-priority fields in their wire form. The first 4-byte big-endian word
/// holds the exclusive flag (top bit) and the 31-bit stream dependency. It is
/// followed by one weight byte.
///
/// The `Frame::Priority` variant always carries one of these. `Frame::Headers`
/// carries one optionally.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct PriorityInfo {
    /// Whether the exclusive bit was set.
    pub(crate) exclusive: bool,
    /// Identifier of the stream depended upon (low 31 bits of the word).
    pub(crate) stream_dependency: u32,
    /// Weight in `1..=256`, i.e. the wire byte plus one.
    pub(crate) weight: u16,
}
impl PriorityInfo {
    /// Number of bytes the priority fields occupy on the wire.
    pub(crate) const WIRE_LEN: u32 = 5;

    /// Parses priority fields from the first `WIRE_LEN` bytes of `input`.
    ///
    /// # Panics
    ///
    /// Panics on slice indexing if `input` is shorter than `WIRE_LEN`. Debug
    /// builds catch this earlier with an assertion.
    pub(crate) fn decode(input: &[u8]) -> Self {
        const EXCLUSIVE_BIT: u32 = 0x8000_0000;
        debug_assert!(input.len() >= Self::WIRE_LEN as usize);
        // Accumulate the first four bytes as a big-endian u32.
        let word = input[..4]
            .iter()
            .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
        Self {
            exclusive: (word & EXCLUSIVE_BIT) == EXCLUSIVE_BIT,
            stream_dependency: word & !EXCLUSIVE_BIT,
            weight: 1 + u16::from(input[4]),
        }
    }
}
impl Frame {
pub(crate) fn decode(input: &[u8]) -> Result<(Self, usize), FrameDecodeError> {
let header = FrameHeader::decode(input).ok_or(FrameDecodeError::Incomplete)?;
let prefix_input = || {
input
.get(FRAME_HEADER_LEN..)
.ok_or(FrameDecodeError::Incomplete)
};
match FrameType::try_from(header.frame_type) {
Ok(FrameType::Data) => {
let (frame, prefix_consumed) = data::decode_prefix(header, prefix_input()?)?;
Ok((frame, FRAME_HEADER_LEN + prefix_consumed))
}
Ok(FrameType::Headers) => {
let (frame, prefix_consumed) = headers::decode_prefix(header, prefix_input()?)?;
Ok((frame, FRAME_HEADER_LEN + prefix_consumed))
}
Ok(FrameType::Continuation) => {
continuation::decode(header).map(|f| (f, FRAME_HEADER_LEN))
}
Ok(FrameType::PushPromise) => Ok((
Frame::PushPromise {
stream_id: header.stream_id,
length: header.length,
},
FRAME_HEADER_LEN,
)),
Ok(FrameType::Priority) => {
let payload = require_payload(input, header)?;
priority::decode(header, payload).map(|f| (f, FRAME_HEADER_LEN + payload.len()))
}
Ok(FrameType::RstStream) => {
let payload = require_payload(input, header)?;
rst_stream::decode(header, payload).map(|f| (f, FRAME_HEADER_LEN + payload.len()))
}
Ok(FrameType::Settings) => {
let payload = require_payload(input, header)?;
settings::decode(header, payload).map(|f| (f, FRAME_HEADER_LEN + payload.len()))
}
Ok(FrameType::Ping) => {
let payload = require_payload(input, header)?;
ping::decode(header, payload).map(|f| (f, FRAME_HEADER_LEN + payload.len()))
}
Ok(FrameType::Goaway) => {
let payload = require_payload(input, header)?;
goaway::decode(header, payload).map(|f| (f, FRAME_HEADER_LEN + payload.len()))
}
Ok(FrameType::WindowUpdate) => {
let payload = require_payload(input, header)?;
window_update::decode(header, payload)
.map(|f| (f, FRAME_HEADER_LEN + payload.len()))
}
Err(frame_type) => Ok((
Frame::Unknown {
stream_id: header.stream_id,
frame_type,
flags: header.flags,
length: header.length,
},
FRAME_HEADER_LEN,
)),
}
}
}
/// Returns the complete payload of the frame described by `header`. The
/// payload must sit immediately after the frame header at the start of `input`.
///
/// Bytes after the payload, such as the next frame, are not included.
///
/// # Errors
///
/// * `FrameDecodeError::Incomplete` if `input` does not yet hold the whole
///   payload.
/// * `FrameDecodeError::Error(H2ErrorCode::FrameSizeError)` if the end offset
///   of the declared payload cannot be represented as a `usize`.
pub(crate) fn require_payload(
    input: &[u8],
    header: FrameHeader,
) -> Result<&[u8], FrameDecodeError> {
    // `header.length` is a crate-visible u32 and is not guaranteed to be capped
    // at the 24-bit wire range. On 32-bit targets, `FRAME_HEADER_LEN + length`
    // could overflow: a panic in debug builds, and a wrapped end offset in
    // release that reports a bogus `Incomplete`. Treat that like any other
    // unrepresentable length.
    let end = usize::try_from(header.length)
        .ok()
        .and_then(|length| FRAME_HEADER_LEN.checked_add(length))
        .ok_or(H2ErrorCode::FrameSizeError)?;
    input
        .get(FRAME_HEADER_LEN..end)
        .ok_or(FrameDecodeError::Incomplete)
}
/// Test helper that builds a complete frame: a header for `frame_type` with the
/// given `flags` and `stream_id`, followed by `payload`.
///
/// Panics if `payload.len()` does not fit in a `u32`. Debug builds also assert,
/// via `FrameHeader::encode`, that it fits the 24-bit length field.
#[cfg(test)]
pub(crate) fn encode_frame(
    frame_type: FrameType,
    flags: u8,
    stream_id: u32,
    payload: &[u8],
) -> Vec<u8> {
    let header = FrameHeader {
        length: u32::try_from(payload.len()).unwrap(),
        frame_type: frame_type as u8,
        flags,
        stream_id,
    };
    let mut frame = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
    frame.resize(FRAME_HEADER_LEN, 0);
    header.encode(&mut frame);
    frame.extend_from_slice(payload);
    frame
}
#[cfg(test)]
mod tests {
    // NOTE(review): kept from the original module; confirm whether any cast
    // here still needs this allowance.
    #![allow(clippy::cast_possible_truncation)]
    use super::*;

    /// Encodes just a frame header. Unlike `encode_frame`, this accepts type
    /// codes that have no `FrameType` variant.
    fn raw_header(length: u32, frame_type: u8, flags: u8, stream_id: u32) -> [u8; FRAME_HEADER_LEN] {
        let mut buf = [0u8; FRAME_HEADER_LEN];
        FrameHeader {
            length,
            frame_type,
            flags,
            stream_id,
        }
        .encode(&mut buf);
        buf
    }

    #[test]
    fn frame_header_roundtrip() {
        let header = FrameHeader {
            length: 0x00_01_02_03,
            frame_type: 0x09,
            flags: 0x0F,
            stream_id: 0x1234_5678,
        };
        let mut buf = [0u8; FRAME_HEADER_LEN];
        header.encode(&mut buf);
        // Pin the exact wire layout, not just encode/decode symmetry.
        assert_eq!(buf, [0x01, 0x02, 0x03, 0x09, 0x0F, 0x12, 0x34, 0x56, 0x78]);
        assert_eq!(FrameHeader::decode(&buf), Some(header));
    }

    #[test]
    fn frame_header_masks_reserved_bit_on_decode() {
        let mut buf = [0u8; FRAME_HEADER_LEN];
        buf[2] = 1;
        buf[3] = 0x06;
        buf[5..9].copy_from_slice(&0xFFFF_FFFFu32.to_be_bytes());
        let decoded = FrameHeader::decode(&buf).unwrap();
        assert_eq!(decoded.stream_id, 0x7FFF_FFFF);
    }

    #[test]
    fn frame_header_incomplete() {
        assert!(FrameHeader::decode(&[0u8; FRAME_HEADER_LEN - 1]).is_none());
    }

    #[test]
    fn frame_type_codes_roundtrip() {
        for code in 0u8..=0x9 {
            let kind = FrameType::try_from(code).unwrap();
            assert_eq!(kind as u8, code);
        }
        assert_eq!(FrameType::try_from(0x0A), Err(0x0A));
        assert_eq!(FrameType::try_from(0xFF), Err(0xFF));
    }

    #[test]
    fn unknown_frame_type_returns_unknown_variant() {
        let payload = [1u8, 2, 3];
        let length = u32::try_from(payload.len()).unwrap();
        let mut buf = raw_header(length, 0xBE, 0xEF, 5).to_vec();
        buf.extend_from_slice(&payload);
        let (frame, consumed) = Frame::decode(&buf).unwrap();
        assert_eq!(consumed, FRAME_HEADER_LEN);
        assert_eq!(
            frame,
            Frame::Unknown {
                stream_id: 5,
                frame_type: 0xBE,
                flags: 0xEF,
                length: 3,
            }
        );
    }

    #[test]
    fn push_promise_variant_surfaced_for_rejection() {
        let buf = encode_frame(FrameType::PushPromise, 0, 1, &[0u8; 8]);
        let (frame, consumed) = Frame::decode(&buf).unwrap();
        assert_eq!(consumed, FRAME_HEADER_LEN);
        assert_eq!(
            frame,
            Frame::PushPromise {
                stream_id: 1,
                length: 8,
            }
        );
    }

    #[test]
    fn incomplete_header_is_incomplete() {
        assert_eq!(Frame::decode(&[0u8; 4]), Err(FrameDecodeError::Incomplete));
    }

    #[test]
    fn incomplete_control_payload_is_incomplete() {
        // The header announces an 8-byte PING payload but only 4 bytes follow.
        let mut buf = encode_frame(FrameType::Ping, 0, 0, &[0u8; 8]);
        buf.truncate(FRAME_HEADER_LEN + 4);
        assert_eq!(Frame::decode(&buf), Err(FrameDecodeError::Incomplete));
    }

    #[test]
    fn require_payload_excludes_trailing_bytes() {
        let mut buf = encode_frame(FrameType::Ping, 0, 0, &[7u8; 8]);
        // Bytes belonging to a following frame must not leak into the payload.
        buf.extend_from_slice(&[9, 9]);
        let header = FrameHeader::decode(&buf).unwrap();
        assert_eq!(require_payload(&buf, header), Ok(&[7u8; 8][..]));
    }
}