use crate::bytes::{BufMut, Bytes, BytesMut};
use super::error::{ErrorCode, H2Error};
pub const FRAME_HEADER_SIZE: usize = 9;
pub const DEFAULT_MAX_FRAME_SIZE: u32 = 16_384;
pub const MAX_FRAME_SIZE: u32 = 16_777_215;
pub const MIN_MAX_FRAME_SIZE: u32 = 16_384;
const MAX_INITIAL_WINDOW_SIZE: u32 = 0x7fff_ffff;
const MAX_WINDOW_UPDATE_INCREMENT: u32 = 0x7fff_ffff;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum FrameType {
Data = 0x0,
Headers = 0x1,
Priority = 0x2,
RstStream = 0x3,
Settings = 0x4,
PushPromise = 0x5,
Ping = 0x6,
GoAway = 0x7,
WindowUpdate = 0x8,
Continuation = 0x9,
}
impl FrameType {
#[must_use]
pub fn from_u8(value: u8) -> Option<Self> {
match value {
0x0 => Some(Self::Data),
0x1 => Some(Self::Headers),
0x2 => Some(Self::Priority),
0x3 => Some(Self::RstStream),
0x4 => Some(Self::Settings),
0x5 => Some(Self::PushPromise),
0x6 => Some(Self::Ping),
0x7 => Some(Self::GoAway),
0x8 => Some(Self::WindowUpdate),
0x9 => Some(Self::Continuation),
_ => None,
}
}
}
pub mod data_flags {
pub const END_STREAM: u8 = 0x1;
pub const PADDED: u8 = 0x8;
}
pub mod headers_flags {
pub const END_STREAM: u8 = 0x1;
pub const END_HEADERS: u8 = 0x4;
pub const PADDED: u8 = 0x8;
pub const PRIORITY: u8 = 0x20;
}
pub mod settings_flags {
pub const ACK: u8 = 0x1;
}
pub mod ping_flags {
pub const ACK: u8 = 0x1;
}
pub mod continuation_flags {
pub const END_HEADERS: u8 = 0x4;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
pub length: u32,
pub frame_type: u8,
pub flags: u8,
pub stream_id: u32,
}
impl FrameHeader {
pub fn parse(src: &mut BytesMut) -> Result<Self, H2Error> {
if src.len() < FRAME_HEADER_SIZE {
return Err(H2Error::protocol("insufficient bytes for frame header"));
}
let length = ((u32::from(src[0])) << 16) | ((u32::from(src[1])) << 8) | u32::from(src[2]);
let frame_type = src[3];
let flags = src[4];
let stream_id = ((u32::from(src[5]) & 0x7f) << 24)
| ((u32::from(src[6])) << 16)
| ((u32::from(src[7])) << 8)
| u32::from(src[8]);
let _ = src.split_to(FRAME_HEADER_SIZE);
Ok(Self {
length,
frame_type,
flags,
stream_id,
})
}
#[inline]
pub fn write(&self, dst: &mut BytesMut) {
assert!(
self.length <= MAX_FRAME_SIZE,
"frame length {0} exceeds 24-bit max {1}",
self.length,
MAX_FRAME_SIZE,
);
let buf: [u8; FRAME_HEADER_SIZE] = [
(self.length >> 16) as u8,
(self.length >> 8) as u8,
self.length as u8,
self.frame_type,
self.flags,
((self.stream_id >> 24) & 0x7f) as u8,
(self.stream_id >> 16) as u8,
(self.stream_id >> 8) as u8,
self.stream_id as u8,
];
dst.extend_from_slice(&buf);
}
#[must_use]
pub fn has_flag(&self, flag: u8) -> bool {
self.flags & flag != 0
}
}
#[inline]
fn try_frame_length(len: usize) -> Result<u32, H2Error> {
if len > MAX_FRAME_SIZE as usize {
return Err(H2Error::frame_size(format!(
"payload length {len} exceeds 24-bit max {MAX_FRAME_SIZE} (br-asupersync-pt23uf)"
)));
}
#[allow(clippy::cast_possible_truncation)]
let len = len as u32;
Ok(len)
}
#[inline]
fn checked_frame_length_sum(lhs: usize, rhs: usize) -> Result<u32, H2Error> {
let len = lhs.checked_add(rhs).ok_or_else(|| {
H2Error::frame_size(format!(
"payload length overflow while adding {lhs} and {rhs} (br-asupersync-pt23uf)"
))
})?;
try_frame_length(len)
}
#[inline]
fn checked_frame_length_product(lhs: usize, rhs: usize) -> Result<u32, H2Error> {
let len = lhs.checked_mul(rhs).ok_or_else(|| {
H2Error::frame_size(format!(
"payload length overflow while multiplying {lhs} by {rhs} (br-asupersync-pt23uf)"
))
})?;
try_frame_length(len)
}
#[derive(Debug, Clone)]
pub enum Frame {
Data(DataFrame),
Headers(HeadersFrame),
Priority(PriorityFrame),
RstStream(RstStreamFrame),
Settings(SettingsFrame),
PushPromise(PushPromiseFrame),
Ping(PingFrame),
GoAway(GoAwayFrame),
WindowUpdate(WindowUpdateFrame),
Continuation(ContinuationFrame),
Unknown {
frame_type: u8,
stream_id: u32,
payload: Bytes,
},
}
impl Frame {
#[must_use]
pub fn stream_id(&self) -> u32 {
match self {
Self::Data(f) => f.stream_id,
Self::Headers(f) => f.stream_id,
Self::Priority(f) => f.stream_id,
Self::RstStream(f) => f.stream_id,
Self::Settings(_) | Self::Ping(_) | Self::GoAway(_) => 0,
Self::PushPromise(f) => f.stream_id,
Self::WindowUpdate(f) => f.stream_id,
Self::Continuation(f) => f.stream_id,
Self::Unknown { stream_id, .. } => *stream_id,
}
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
match self {
Self::Data(f) => f.encode(dst),
Self::Headers(f) => f.encode(dst),
Self::Priority(f) => f.encode(dst),
Self::RstStream(f) => f.encode(dst),
Self::Settings(f) => f.encode(dst),
Self::PushPromise(f) => f.encode(dst),
Self::Ping(f) => f.encode(dst),
Self::GoAway(f) => f.encode(dst),
Self::WindowUpdate(f) => f.encode(dst),
Self::Continuation(f) => f.encode(dst),
Self::Unknown {
frame_type,
stream_id,
payload,
} => {
let header = FrameHeader {
length: try_frame_length(payload.len())?,
frame_type: *frame_type,
flags: 0,
stream_id: *stream_id,
};
header.write(dst);
dst.extend_from_slice(payload);
Ok(())
}
}
}
}
#[derive(Debug, Clone)]
pub struct DataFrame {
pub stream_id: u32,
pub data: Bytes,
pub end_stream: bool,
}
impl DataFrame {
#[must_use]
pub fn new(stream_id: u32, data: Bytes, end_stream: bool) -> Self {
Self {
stream_id,
data,
end_stream,
}
}
pub fn parse(header: &FrameHeader, payload: Bytes) -> Result<Self, H2Error> {
if header.stream_id == 0 {
return Err(H2Error::protocol("DATA frame with stream ID 0"));
}
let mut data = payload;
let end_stream = header.has_flag(data_flags::END_STREAM);
if header.has_flag(data_flags::PADDED) {
if data.is_empty() {
return Err(H2Error::protocol(
"PADDED DATA frame with no padding length",
));
}
let pad_length = data[0] as usize;
data = data.slice(1..);
if pad_length > data.len() {
return Err(H2Error::protocol("DATA frame padding exceeds data length"));
}
data = data.slice(..data.len() - pad_length);
}
Ok(Self {
stream_id: header.stream_id,
data,
end_stream,
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
if self.stream_id == 0 {
return Err(H2Error::protocol("DATA frame with stream ID 0"));
}
let mut flags = 0u8;
if self.end_stream {
flags |= data_flags::END_STREAM;
}
let header = FrameHeader {
length: try_frame_length(self.data.len())?,
frame_type: FrameType::Data as u8,
flags,
stream_id: self.stream_id,
};
header.write(dst);
dst.extend_from_slice(&self.data);
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct HeadersFrame {
pub stream_id: u32,
pub header_block: Bytes,
pub end_stream: bool,
pub end_headers: bool,
pub priority: Option<PrioritySpec>,
}
impl HeadersFrame {
#[must_use]
pub fn new(stream_id: u32, header_block: Bytes, end_stream: bool, end_headers: bool) -> Self {
Self {
stream_id,
header_block,
end_stream,
end_headers,
priority: None,
}
}
pub fn parse(header: &FrameHeader, mut payload: Bytes) -> Result<Self, H2Error> {
if header.stream_id == 0 {
return Err(H2Error::protocol("HEADERS frame with stream ID 0"));
}
let end_stream = header.has_flag(headers_flags::END_STREAM);
let end_headers = header.has_flag(headers_flags::END_HEADERS);
let padded = header.has_flag(headers_flags::PADDED);
let has_priority = header.has_flag(headers_flags::PRIORITY);
let mut pad_length = 0;
if padded {
if payload.is_empty() {
return Err(H2Error::protocol(
"PADDED HEADERS frame with no padding length",
));
}
pad_length = payload[0] as usize;
payload = payload.slice(1..);
let priority_bytes = if has_priority { 5usize } else { 0 };
if pad_length.saturating_add(priority_bytes) > payload.len() {
return Err(H2Error::protocol(
"HEADERS frame padding exceeds data length",
));
}
}
let priority = if has_priority {
if payload.len() < 5 {
return Err(H2Error::protocol("HEADERS frame too short for priority"));
}
let exclusive = payload[0] & 0x80 != 0;
let dependency = ((u32::from(payload[0]) & 0x7f) << 24)
| ((u32::from(payload[1])) << 16)
| ((u32::from(payload[2])) << 8)
| u32::from(payload[3]);
if dependency == header.stream_id {
return Err(H2Error::stream(
header.stream_id,
ErrorCode::ProtocolError,
"stream cannot depend on itself",
));
}
let weight = payload[4];
payload = payload.slice(5..);
Some(PrioritySpec {
exclusive,
dependency,
weight,
})
} else {
None
};
if padded {
payload = payload.slice(..payload.len() - pad_length);
}
Ok(Self {
stream_id: header.stream_id,
header_block: payload,
end_stream,
end_headers,
priority,
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
if self.stream_id == 0 {
return Err(H2Error::protocol("HEADERS frame with stream ID 0"));
}
if self
.priority
.is_some_and(|priority| priority.dependency & 0x7fff_ffff == self.stream_id)
{
return Err(H2Error::stream(
self.stream_id,
ErrorCode::ProtocolError,
"stream cannot depend on itself",
));
}
let mut flags = 0u8;
if self.end_stream {
flags |= headers_flags::END_STREAM;
}
if self.end_headers {
flags |= headers_flags::END_HEADERS;
}
let length = if self.priority.is_some() {
flags |= headers_flags::PRIORITY;
checked_frame_length_sum(self.header_block.len(), 5)?
} else {
try_frame_length(self.header_block.len())?
};
let header = FrameHeader {
length,
frame_type: FrameType::Headers as u8,
flags,
stream_id: self.stream_id,
};
header.write(dst);
if let Some(ref priority) = self.priority {
let mut dep = priority.dependency;
if priority.exclusive {
dep |= 0x8000_0000;
}
dst.put_u32(dep);
dst.put_u8(priority.weight);
}
dst.extend_from_slice(&self.header_block);
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrioritySpec {
pub exclusive: bool,
pub dependency: u32,
pub weight: u8,
}
#[derive(Debug, Clone, Copy)]
pub struct PriorityFrame {
pub stream_id: u32,
pub priority: PrioritySpec,
}
impl PriorityFrame {
pub fn parse(header: &FrameHeader, payload: &Bytes) -> Result<Self, H2Error> {
if header.stream_id == 0 {
return Err(H2Error::protocol("PRIORITY frame with stream ID 0"));
}
if payload.len() != 5 {
return Err(H2Error::stream(
header.stream_id,
ErrorCode::FrameSizeError,
"PRIORITY frame must be 5 bytes",
));
}
let exclusive = payload[0] & 0x80 != 0;
let dependency = ((u32::from(payload[0]) & 0x7f) << 24)
| ((u32::from(payload[1])) << 16)
| ((u32::from(payload[2])) << 8)
| u32::from(payload[3]);
if dependency == header.stream_id {
return Err(H2Error::stream(
header.stream_id,
ErrorCode::ProtocolError,
"stream cannot depend on itself",
));
}
let weight = payload[4];
Ok(Self {
stream_id: header.stream_id,
priority: PrioritySpec {
exclusive,
dependency,
weight,
},
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
if self.stream_id == 0 {
return Err(H2Error::protocol("PRIORITY frame with stream ID 0"));
}
if self.priority.dependency & 0x7fff_ffff == self.stream_id {
return Err(H2Error::stream(
self.stream_id,
ErrorCode::ProtocolError,
"stream cannot depend on itself",
));
}
let header = FrameHeader {
length: 5,
frame_type: FrameType::Priority as u8,
flags: 0,
stream_id: self.stream_id,
};
header.write(dst);
let mut dep = self.priority.dependency;
if self.priority.exclusive {
dep |= 0x8000_0000;
}
dst.put_u32(dep);
dst.put_u8(self.priority.weight);
Ok(())
}
}
#[derive(Debug, Clone, Copy)]
pub struct RstStreamFrame {
pub stream_id: u32,
pub error_code: ErrorCode,
}
impl RstStreamFrame {
#[must_use]
pub fn new(stream_id: u32, error_code: ErrorCode) -> Self {
Self {
stream_id,
error_code,
}
}
pub fn parse(header: &FrameHeader, payload: &Bytes) -> Result<Self, H2Error> {
if header.stream_id == 0 {
return Err(H2Error::protocol("RST_STREAM frame with stream ID 0"));
}
if payload.len() != 4 {
return Err(H2Error::frame_size("RST_STREAM frame must be 4 bytes"));
}
let error_code = ErrorCode::from_u32(
((u32::from(payload[0])) << 24)
| ((u32::from(payload[1])) << 16)
| ((u32::from(payload[2])) << 8)
| u32::from(payload[3]),
);
Ok(Self {
stream_id: header.stream_id,
error_code,
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
if self.stream_id == 0 {
return Err(H2Error::protocol("RST_STREAM frame with stream ID 0"));
}
let header = FrameHeader {
length: 4,
frame_type: FrameType::RstStream as u8,
flags: 0,
stream_id: self.stream_id,
};
header.write(dst);
dst.put_u32(self.error_code.into());
Ok(())
}
}
#[derive(Debug, Clone, Default)]
pub struct SettingsFrame {
pub settings: Vec<Setting>,
pub ack: bool,
}
impl SettingsFrame {
#[must_use]
pub fn new(settings: Vec<Setting>) -> Self {
Self {
settings,
ack: false,
}
}
#[must_use]
pub fn ack() -> Self {
Self {
settings: Vec::new(),
ack: true,
}
}
pub fn parse(header: &FrameHeader, payload: &Bytes) -> Result<Self, H2Error> {
if header.stream_id != 0 {
return Err(H2Error::protocol("SETTINGS frame with non-zero stream ID"));
}
let ack = header.has_flag(settings_flags::ACK);
if ack && !payload.is_empty() {
return Err(H2Error::frame_size("SETTINGS ACK with non-zero length"));
}
if !payload.len().is_multiple_of(6) {
return Err(H2Error::frame_size(
"SETTINGS frame length not multiple of 6",
));
}
let mut settings = Vec::new();
let mut cursor: usize = 0;
while cursor.saturating_add(6) <= payload.len() {
let id =
((u16::from(payload[cursor])) << 8) | u16::from(payload[cursor.saturating_add(1)]);
let value = ((u32::from(payload[cursor.saturating_add(2)])) << 24)
| ((u32::from(payload[cursor.saturating_add(3)])) << 16)
| ((u32::from(payload[cursor.saturating_add(4)])) << 8)
| u32::from(payload[cursor.saturating_add(5)]);
if id == 0x2 && value > 1 {
return Err(H2Error::protocol("SETTINGS_ENABLE_PUSH must be 0 or 1"));
}
if id == 0x4 && value > MAX_INITIAL_WINDOW_SIZE {
return Err(H2Error::flow_control(
"SETTINGS_INITIAL_WINDOW_SIZE exceeds maximum",
));
}
if id == 0x5 && !(MIN_MAX_FRAME_SIZE..=MAX_FRAME_SIZE).contains(&value) {
return Err(H2Error::protocol("SETTINGS_MAX_FRAME_SIZE out of bounds"));
}
if let Some(setting) = Setting::from_id_value(id, value) {
settings.push(setting);
}
cursor = cursor.saturating_add(6);
}
Ok(Self { settings, ack })
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
let mut flags = 0u8;
if self.ack {
flags |= settings_flags::ACK;
}
let encoded_settings: &[Setting] = if self.ack { &[] } else { &self.settings };
for setting in encoded_settings {
validate_setting_for_encode(*setting)?;
}
let header = FrameHeader {
length: checked_frame_length_product(encoded_settings.len(), 6)?,
frame_type: FrameType::Settings as u8,
flags,
stream_id: 0,
};
header.write(dst);
for setting in encoded_settings {
dst.put_u16(setting.id());
dst.put_u32(setting.value());
}
Ok(())
}
}
fn validate_setting_for_encode(setting: Setting) -> Result<(), H2Error> {
match setting {
Setting::InitialWindowSize(value) if value > MAX_INITIAL_WINDOW_SIZE => Err(
H2Error::flow_control("SETTINGS_INITIAL_WINDOW_SIZE exceeds maximum"),
),
Setting::MaxFrameSize(value) if !(MIN_MAX_FRAME_SIZE..=MAX_FRAME_SIZE).contains(&value) => {
Err(H2Error::protocol("SETTINGS_MAX_FRAME_SIZE out of bounds"))
}
_ => Ok(()),
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Setting {
HeaderTableSize(u32),
EnablePush(bool),
MaxConcurrentStreams(u32),
InitialWindowSize(u32),
MaxFrameSize(u32),
MaxHeaderListSize(u32),
}
impl Setting {
#[must_use]
pub fn from_id_value(id: u16, value: u32) -> Option<Self> {
match id {
0x1 => Some(Self::HeaderTableSize(value)),
0x2 => Some(Self::EnablePush(value != 0)),
0x3 => Some(Self::MaxConcurrentStreams(value)),
0x4 => Some(Self::InitialWindowSize(value)),
0x5 => Some(Self::MaxFrameSize(value)),
0x6 => Some(Self::MaxHeaderListSize(value)),
_ => None, }
}
#[must_use]
pub fn id(&self) -> u16 {
match self {
Self::HeaderTableSize(_) => 0x1,
Self::EnablePush(_) => 0x2,
Self::MaxConcurrentStreams(_) => 0x3,
Self::InitialWindowSize(_) => 0x4,
Self::MaxFrameSize(_) => 0x5,
Self::MaxHeaderListSize(_) => 0x6,
}
}
#[must_use]
pub fn value(&self) -> u32 {
match self {
Self::HeaderTableSize(v)
| Self::MaxConcurrentStreams(v)
| Self::InitialWindowSize(v)
| Self::MaxFrameSize(v)
| Self::MaxHeaderListSize(v) => *v,
Self::EnablePush(v) => u32::from(*v),
}
}
}
#[derive(Debug, Clone)]
pub struct PushPromiseFrame {
pub stream_id: u32,
pub promised_stream_id: u32,
pub header_block: Bytes,
pub end_headers: bool,
}
impl PushPromiseFrame {
pub fn parse(header: &FrameHeader, mut payload: Bytes) -> Result<Self, H2Error> {
if header.stream_id == 0 {
return Err(H2Error::protocol("PUSH_PROMISE frame with stream ID 0"));
}
let end_headers = header.has_flag(headers_flags::END_HEADERS);
let padded = header.has_flag(headers_flags::PADDED);
let mut pad_length = 0;
if padded {
if payload.is_empty() {
return Err(H2Error::protocol(
"PADDED PUSH_PROMISE frame with no padding length",
));
}
pad_length = payload[0] as usize;
payload = payload.slice(1..);
}
if payload.len() < 4 {
return Err(H2Error::protocol("PUSH_PROMISE frame too short"));
}
let promised_stream_id = ((u32::from(payload[0]) & 0x7f) << 24)
| ((u32::from(payload[1])) << 16)
| ((u32::from(payload[2])) << 8)
| u32::from(payload[3]);
if promised_stream_id == 0 {
return Err(H2Error::protocol(
"PUSH_PROMISE frame with promised stream ID 0",
));
}
payload = payload.slice(4..);
if padded {
if pad_length > payload.len() {
return Err(H2Error::protocol(
"PUSH_PROMISE frame padding exceeds data length",
));
}
payload = payload.slice(..payload.len() - pad_length);
}
Ok(Self {
stream_id: header.stream_id,
promised_stream_id,
header_block: payload,
end_headers,
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
if self.stream_id == 0 {
return Err(H2Error::protocol("PUSH_PROMISE frame with stream ID 0"));
}
if self.promised_stream_id.trailing_zeros() >= 31 {
return Err(H2Error::protocol(
"PUSH_PROMISE frame with promised stream ID 0",
));
}
let mut flags = 0u8;
if self.end_headers {
flags |= headers_flags::END_HEADERS;
}
let header = FrameHeader {
length: checked_frame_length_sum(self.header_block.len(), 4)?,
frame_type: FrameType::PushPromise as u8,
flags,
stream_id: self.stream_id,
};
header.write(dst);
dst.put_u32(self.promised_stream_id & 0x7fff_ffff);
dst.extend_from_slice(&self.header_block);
Ok(())
}
}
#[derive(Debug, Clone, Copy)]
pub struct PingFrame {
pub opaque_data: [u8; 8],
pub ack: bool,
}
impl PingFrame {
#[must_use]
pub fn new(opaque_data: [u8; 8]) -> Self {
Self {
opaque_data,
ack: false,
}
}
#[must_use]
pub fn ack(opaque_data: [u8; 8]) -> Self {
Self {
opaque_data,
ack: true,
}
}
pub fn parse(header: &FrameHeader, payload: &Bytes) -> Result<Self, H2Error> {
if header.stream_id != 0 {
return Err(H2Error::protocol("PING frame with non-zero stream ID"));
}
if payload.len() != 8 {
return Err(H2Error::frame_size("PING frame must be 8 bytes"));
}
let mut opaque_data = [0u8; 8];
opaque_data.copy_from_slice(&payload[..8]);
Ok(Self {
opaque_data,
ack: header.has_flag(ping_flags::ACK),
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
let mut flags = 0u8;
if self.ack {
flags |= ping_flags::ACK;
}
let header = FrameHeader {
length: 8,
frame_type: FrameType::Ping as u8,
flags,
stream_id: 0,
};
header.write(dst);
dst.extend_from_slice(&self.opaque_data);
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct GoAwayFrame {
pub last_stream_id: u32,
pub error_code: ErrorCode,
pub debug_data: Bytes,
}
impl GoAwayFrame {
#[must_use]
pub fn new(last_stream_id: u32, error_code: ErrorCode) -> Self {
Self {
last_stream_id,
error_code,
debug_data: Bytes::new(),
}
}
pub fn parse(header: &FrameHeader, payload: &Bytes) -> Result<Self, H2Error> {
if header.stream_id != 0 {
return Err(H2Error::protocol("GOAWAY frame with non-zero stream ID"));
}
if payload.len() < 8 {
return Err(H2Error::frame_size("GOAWAY frame must be at least 8 bytes"));
}
let last_stream_id = ((u32::from(payload[0]) & 0x7f) << 24)
| ((u32::from(payload[1])) << 16)
| ((u32::from(payload[2])) << 8)
| u32::from(payload[3]);
let error_code = ErrorCode::from_u32(
((u32::from(payload[4])) << 24)
| ((u32::from(payload[5])) << 16)
| ((u32::from(payload[6])) << 8)
| u32::from(payload[7]),
);
let debug_data = payload.slice(8..);
Ok(Self {
last_stream_id,
error_code,
debug_data,
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
let header = FrameHeader {
length: checked_frame_length_sum(self.debug_data.len(), 8)?,
frame_type: FrameType::GoAway as u8,
flags: 0,
stream_id: 0,
};
header.write(dst);
dst.put_u32(self.last_stream_id & 0x7fff_ffff);
dst.put_u32(self.error_code.into());
dst.extend_from_slice(&self.debug_data);
Ok(())
}
}
#[derive(Debug, Clone, Copy)]
pub struct WindowUpdateFrame {
pub stream_id: u32,
pub increment: u32,
}
impl WindowUpdateFrame {
#[must_use]
pub fn new(stream_id: u32, increment: u32) -> Self {
Self {
stream_id,
increment,
}
}
pub fn parse(header: &FrameHeader, payload: &Bytes) -> Result<Self, H2Error> {
if payload.len() != 4 {
return Err(H2Error::frame_size("WINDOW_UPDATE frame must be 4 bytes"));
}
let increment = ((u32::from(payload[0]) & 0x7f) << 24)
| ((u32::from(payload[1])) << 16)
| ((u32::from(payload[2])) << 8)
| u32::from(payload[3]);
if increment == 0 {
return Err(invalid_window_update_increment_error(
header.stream_id,
"WINDOW_UPDATE with zero increment",
));
}
Ok(Self {
stream_id: header.stream_id,
increment,
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
if self.increment == 0 {
return Err(invalid_window_update_increment_error(
self.stream_id,
"WINDOW_UPDATE with zero increment",
));
}
if self.increment > MAX_WINDOW_UPDATE_INCREMENT {
return Err(invalid_window_update_increment_error(
self.stream_id,
format!(
"WINDOW_UPDATE increment {} exceeds 31-bit maximum",
self.increment
),
));
}
let header = FrameHeader {
length: 4,
frame_type: FrameType::WindowUpdate as u8,
flags: 0,
stream_id: self.stream_id,
};
header.write(dst);
dst.put_u32(self.increment);
Ok(())
}
}
fn invalid_window_update_increment_error(stream_id: u32, message: impl Into<String>) -> H2Error {
if stream_id == 0 {
H2Error::protocol(message)
} else {
H2Error::stream(stream_id, ErrorCode::ProtocolError, message)
}
}
#[derive(Debug, Clone)]
pub struct ContinuationFrame {
pub stream_id: u32,
pub header_block: Bytes,
pub end_headers: bool,
}
impl ContinuationFrame {
pub fn parse(header: &FrameHeader, payload: Bytes) -> Result<Self, H2Error> {
if header.stream_id == 0 {
return Err(H2Error::protocol("CONTINUATION frame with stream ID 0"));
}
Ok(Self {
stream_id: header.stream_id,
header_block: payload,
end_headers: header.has_flag(continuation_flags::END_HEADERS),
})
}
#[inline]
pub fn encode(&self, dst: &mut BytesMut) -> Result<(), H2Error> {
if self.stream_id == 0 {
return Err(H2Error::protocol("CONTINUATION frame with stream ID 0"));
}
let mut flags = 0u8;
if self.end_headers {
flags |= continuation_flags::END_HEADERS;
}
let header = FrameHeader {
length: try_frame_length(self.header_block.len())?,
frame_type: FrameType::Continuation as u8,
flags,
stream_id: self.stream_id,
};
header.write(dst);
dst.extend_from_slice(&self.header_block);
Ok(())
}
}
pub fn parse_frame(header: &FrameHeader, payload: Bytes) -> Result<Frame, H2Error> {
let frame_type = FrameType::from_u8(header.frame_type);
match frame_type {
Some(FrameType::Data) => Ok(Frame::Data(DataFrame::parse(header, payload)?)),
Some(FrameType::Headers) => Ok(Frame::Headers(HeadersFrame::parse(header, payload)?)),
Some(FrameType::Priority) => Ok(Frame::Priority(PriorityFrame::parse(header, &payload)?)),
Some(FrameType::RstStream) => {
Ok(Frame::RstStream(RstStreamFrame::parse(header, &payload)?))
}
Some(FrameType::Settings) => Ok(Frame::Settings(SettingsFrame::parse(header, &payload)?)),
Some(FrameType::PushPromise) => Ok(Frame::PushPromise(PushPromiseFrame::parse(
header, payload,
)?)),
Some(FrameType::Ping) => Ok(Frame::Ping(PingFrame::parse(header, &payload)?)),
Some(FrameType::GoAway) => Ok(Frame::GoAway(GoAwayFrame::parse(header, &payload)?)),
Some(FrameType::WindowUpdate) => Ok(Frame::WindowUpdate(WindowUpdateFrame::parse(
header, &payload,
)?)),
Some(FrameType::Continuation) => Ok(Frame::Continuation(ContinuationFrame::parse(
header, payload,
)?)),
None => Ok(Frame::Unknown {
frame_type: header.frame_type,
stream_id: header.stream_id,
payload,
}),
}
}
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use super::*;
#[test]
fn test_frame_header_roundtrip() {
let original = FrameHeader {
length: 0x0012_3456,
frame_type: FrameType::Data as u8,
flags: 0x05,
stream_id: 0x7654_3210,
};
let mut buf = BytesMut::new();
original.write(&mut buf);
assert_eq!(buf.len(), FRAME_HEADER_SIZE);
let parsed = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(parsed.length, original.length);
assert_eq!(parsed.frame_type, original.frame_type);
assert_eq!(parsed.flags, original.flags);
assert_eq!(parsed.stream_id, original.stream_id & 0x7fff_ffff);
}
#[test]
fn frame_header_reserved_stream_id_bit_is_ignored_on_parse_and_write() {
let mut wire = BytesMut::from(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x01][..]);
let parsed = FrameHeader::parse(&mut wire).expect("reserved bit is not part of stream id");
assert_eq!(parsed.stream_id, 1);
assert!(wire.is_empty());
let mut encoded = BytesMut::new();
FrameHeader {
length: 0,
frame_type: FrameType::Data as u8,
flags: 0,
stream_id: 0x8000_0001,
}
.write(&mut encoded);
assert_eq!(&encoded[5..9], &[0x00, 0x00, 0x00, 0x01]);
}
#[test]
fn test_frame_header_parse_insufficient_bytes() {
let mut buf = BytesMut::from(&b"\x00\x00\x00\x00\x00\x00\x00\x00"[..]);
let err = FrameHeader::parse(&mut buf).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_data_frame_roundtrip() {
let original = DataFrame::new(1, Bytes::from_static(b"hello"), true);
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = DataFrame::parse(&header, payload).unwrap();
assert_eq!(parsed.stream_id, original.stream_id);
assert_eq!(parsed.data, original.data);
assert_eq!(parsed.end_stream, original.end_stream);
}
#[test]
fn test_settings_frame_roundtrip() {
let original = SettingsFrame::new(vec![
Setting::HeaderTableSize(4096),
Setting::MaxConcurrentStreams(100),
Setting::InitialWindowSize(65535),
]);
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = SettingsFrame::parse(&header, &payload).unwrap();
assert!(!parsed.ack);
assert_eq!(parsed.settings.len(), 3);
}
#[test]
fn test_settings_ack() {
let original = SettingsFrame::ack();
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(header.length, 0);
assert!(header.has_flag(settings_flags::ACK));
}
#[test]
fn test_settings_ack_serialization_drops_payload_per_rfc9113() {
let mut buf = BytesMut::new();
SettingsFrame {
settings: vec![Setting::HeaderTableSize(4096)],
ack: true,
}
.encode(&mut buf)
.expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(header.length, 0, "SETTINGS ACK must be zero-length");
assert!(header.has_flag(settings_flags::ACK));
assert!(
buf.is_empty(),
"SETTINGS ACK must not carry serialized settings payload"
);
}
#[test]
fn test_ping_roundtrip() {
let original = PingFrame::new([1, 2, 3, 4, 5, 6, 7, 8]);
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = PingFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.opaque_data, original.opaque_data);
assert!(!parsed.ack);
}
#[test]
fn test_goaway_roundtrip() {
let original = GoAwayFrame::new(100, ErrorCode::NoError);
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = GoAwayFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.last_stream_id, 100);
assert_eq!(parsed.error_code, ErrorCode::NoError);
}
#[test]
fn test_window_update_roundtrip() {
let original = WindowUpdateFrame::new(1, 65535);
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = WindowUpdateFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.stream_id, 1);
assert_eq!(parsed.increment, 65535);
}
#[test]
fn test_headers_frame_roundtrip() {
let original = HeadersFrame::new(3, Bytes::from_static(b"header-block"), false, true);
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = HeadersFrame::parse(&header, payload).unwrap();
assert_eq!(parsed.stream_id, original.stream_id);
assert_eq!(parsed.header_block, original.header_block);
assert_eq!(parsed.end_stream, original.end_stream);
assert_eq!(parsed.end_headers, original.end_headers);
}
#[test]
fn test_headers_frame_with_priority_roundtrip() {
let mut original = HeadersFrame::new(5, Bytes::from_static(b"hdr"), true, true);
original.priority = Some(PrioritySpec {
exclusive: true,
dependency: 1,
weight: 128,
});
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = HeadersFrame::parse(&header, payload).unwrap();
assert_eq!(parsed.stream_id, 5);
assert!(parsed.end_stream);
assert!(parsed.priority.is_some());
let p = parsed.priority.unwrap();
assert!(p.exclusive);
assert_eq!(p.dependency, 1);
assert_eq!(p.weight, 128);
}
#[test]
fn test_priority_frame_roundtrip() {
let original = PriorityFrame {
stream_id: 7,
priority: PrioritySpec {
exclusive: false,
dependency: 3,
weight: 64,
},
};
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = PriorityFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.stream_id, original.stream_id);
assert_eq!(parsed.priority.exclusive, original.priority.exclusive);
assert_eq!(parsed.priority.dependency, original.priority.dependency);
assert_eq!(parsed.priority.weight, original.priority.weight);
}
#[test]
fn test_rst_stream_roundtrip() {
let original = RstStreamFrame::new(11, ErrorCode::Cancel);
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = RstStreamFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.stream_id, original.stream_id);
assert_eq!(parsed.error_code, original.error_code);
}
#[test]
fn test_push_promise_roundtrip() {
let original = PushPromiseFrame {
stream_id: 1,
promised_stream_id: 2,
header_block: Bytes::from_static(b"pushed-headers"),
end_headers: true,
};
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = PushPromiseFrame::parse(&header, payload).unwrap();
assert_eq!(parsed.stream_id, original.stream_id);
assert_eq!(parsed.promised_stream_id, original.promised_stream_id);
assert_eq!(parsed.header_block, original.header_block);
assert_eq!(parsed.end_headers, original.end_headers);
}
#[test]
fn test_continuation_roundtrip() {
let original = ContinuationFrame {
stream_id: 9,
header_block: Bytes::from_static(b"continued-headers"),
end_headers: false,
};
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = ContinuationFrame::parse(&header, payload).unwrap();
assert_eq!(parsed.stream_id, original.stream_id);
assert_eq!(parsed.header_block, original.header_block);
assert_eq!(parsed.end_headers, original.end_headers);
}
#[test]
fn test_data_frame_stream_id_zero_rejected() {
let header = FrameHeader {
length: 5,
frame_type: FrameType::Data as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(b"hello");
let err = DataFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn data_frame_encode_rejects_stream_id_zero_without_partial_write() {
let frame = DataFrame::new(0, Bytes::from_static(b"hello"), true);
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("invalid DATA stream id must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, None);
assert!(
buf.is_empty(),
"invalid DATA frame must not partially encode"
);
}
#[test]
fn test_data_frame_invalid_padding() {
let header = FrameHeader {
length: 5,
frame_type: FrameType::Data as u8,
flags: data_flags::PADDED,
stream_id: 1,
};
let payload = Bytes::from_static(&[10, b'a', b'b', b'c', b'd']);
let err = DataFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_data_frame_padded_empty_payload_rejected() {
let header = FrameHeader {
length: 0,
frame_type: FrameType::Data as u8,
flags: data_flags::PADDED,
stream_id: 1,
};
let payload = Bytes::new();
let err = DataFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_data_frame_padding_exact_length_accepted() {
let header = FrameHeader {
length: 2,
frame_type: FrameType::Data as u8,
flags: data_flags::PADDED,
stream_id: 1,
};
let payload = Bytes::from_static(&[1, 0xff]);
let parsed = DataFrame::parse(&header, payload).unwrap();
assert_eq!(parsed.stream_id, 1);
assert!(parsed.data.is_empty());
}
#[test]
fn test_headers_frame_stream_id_zero_rejected() {
let header = FrameHeader {
length: 5,
frame_type: FrameType::Headers as u8,
flags: headers_flags::END_HEADERS,
stream_id: 0,
};
let payload = Bytes::from_static(b"hdr");
let err = HeadersFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn headers_frame_encode_rejects_stream_id_zero_without_partial_write() {
let frame = HeadersFrame::new(0, Bytes::from_static(b"hdr"), false, true);
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("invalid HEADERS stream id must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, None);
assert!(
buf.is_empty(),
"invalid HEADERS frame must not partially encode"
);
}
#[test]
fn headers_frame_encode_rejects_self_dependency_without_partial_write() {
let mut frame = HeadersFrame::new(3, Bytes::from_static(b"hdr"), false, true);
frame.priority = Some(PrioritySpec {
exclusive: false,
dependency: 3,
weight: 16,
});
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("self-dependent HEADERS frame must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, Some(3));
assert!(
buf.is_empty(),
"invalid HEADERS frame must not partially encode"
);
}
#[test]
fn ujytci_headers_padded_priority_overflow_returns_connection_error() {
let header = FrameHeader {
length: 7,
frame_type: FrameType::Headers as u8,
flags: headers_flags::PADDED | headers_flags::PRIORITY | headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::from_static(&[
10, 0, 0, 0, 1, 16, 0, ]);
let err = HeadersFrame::parse(&header, payload).expect_err("must reject");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert!(
err.stream_id.is_none(),
"padding-overflow MUST be connection-level (stream_id=None), got {:?}",
err.stream_id
);
assert!(
err.message.contains("padding exceeds"),
"expected padding-overflow message, got: {}",
err.message
);
}
#[test]
fn ujytci_headers_padded_priority_valid_parses_successfully() {
let header = FrameHeader {
length: 9,
frame_type: FrameType::Headers as u8,
flags: headers_flags::PADDED | headers_flags::PRIORITY | headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::from_static(&[
2, 0, 0, 0, 5, 8, 0xFE, 0, 0, ]);
let parsed = HeadersFrame::parse(&header, payload).expect("must parse");
assert_eq!(parsed.priority.expect("priority present").dependency, 5);
assert_eq!(parsed.header_block.len(), 1);
assert_eq!(parsed.header_block[0], 0xFE);
}
#[test]
fn test_headers_frame_priority_too_short() {
let header = FrameHeader {
length: 3,
frame_type: FrameType::Headers as u8,
flags: headers_flags::PRIORITY | headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::from_static(&[0, 0, 0]);
let err = HeadersFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_headers_frame_padded_empty_payload_rejected() {
let header = FrameHeader {
length: 0,
frame_type: FrameType::Headers as u8,
flags: headers_flags::PADDED | headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::new();
let err = HeadersFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_headers_frame_padding_exact_length_accepted() {
let header = FrameHeader {
length: 3,
frame_type: FrameType::Headers as u8,
flags: headers_flags::PADDED | headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::from_static(&[2, b'a', b'b']);
let parsed = HeadersFrame::parse(&header, payload).unwrap();
assert_eq!(parsed.stream_id, 1);
assert!(parsed.header_block.is_empty());
}
#[test]
fn test_priority_frame_stream_id_zero_rejected() {
let header = FrameHeader {
length: 5,
frame_type: FrameType::Priority as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 0, 0, 1, 16]);
let err = PriorityFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn priority_frame_encode_rejects_stream_id_zero_without_partial_write() {
let frame = PriorityFrame {
stream_id: 0,
priority: PrioritySpec {
exclusive: false,
dependency: 1,
weight: 16,
},
};
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("invalid PRIORITY stream id must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, None);
assert!(
buf.is_empty(),
"invalid PRIORITY frame must not partially encode"
);
}
#[test]
fn priority_frame_encode_rejects_self_dependency_without_partial_write() {
let frame = PriorityFrame {
stream_id: 5,
priority: PrioritySpec {
exclusive: true,
dependency: 5,
weight: 16,
},
};
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("self-dependent PRIORITY frame must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, Some(5));
assert!(
buf.is_empty(),
"invalid PRIORITY frame must not partially encode"
);
}
#[test]
fn test_priority_frame_wrong_size() {
let header = FrameHeader {
length: 4,
frame_type: FrameType::Priority as u8,
flags: 0,
stream_id: 1,
};
let payload = Bytes::from_static(&[0, 0, 0, 1]);
let err = PriorityFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
assert_eq!(err.stream_id, Some(1));
}
#[test]
fn test_rst_stream_stream_id_zero_rejected() {
let header = FrameHeader {
length: 4,
frame_type: FrameType::RstStream as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 0, 0, 0]);
let err = RstStreamFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn rst_stream_encode_rejects_stream_id_zero_without_partial_write() {
let frame = RstStreamFrame::new(0, ErrorCode::Cancel);
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("invalid RST_STREAM stream id must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, None);
assert!(
buf.is_empty(),
"invalid RST_STREAM frame must not partially encode"
);
}
#[test]
fn test_rst_stream_wrong_size() {
let header = FrameHeader {
length: 3,
frame_type: FrameType::RstStream as u8,
flags: 0,
stream_id: 1,
};
let payload = Bytes::from_static(&[0, 0, 0]);
let err = RstStreamFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
}
#[test]
fn test_settings_frame_non_zero_stream_id_rejected() {
let header = FrameHeader {
length: 0,
frame_type: FrameType::Settings as u8,
flags: 0,
stream_id: 1,
};
let payload = Bytes::new();
let err = SettingsFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_settings_ack_with_payload_rejected() {
let header = FrameHeader {
length: 6,
frame_type: FrameType::Settings as u8,
flags: settings_flags::ACK,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 1, 0, 0, 0, 1]);
let err = SettingsFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
}
#[test]
fn test_settings_wrong_length() {
let header = FrameHeader {
length: 5, frame_type: FrameType::Settings as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 1, 0, 0, 0]);
let err = SettingsFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
}
#[test]
fn test_push_promise_stream_id_zero_rejected() {
let header = FrameHeader {
length: 5,
frame_type: FrameType::PushPromise as u8,
flags: headers_flags::END_HEADERS,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 0, 0, 2, 0]);
let err = PushPromiseFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_push_promise_too_short() {
let header = FrameHeader {
length: 3,
frame_type: FrameType::PushPromise as u8,
flags: headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::from_static(&[0, 0, 2]);
let err = PushPromiseFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_push_promise_promised_stream_id_zero_rejected() {
let header = FrameHeader {
length: 5,
frame_type: FrameType::PushPromise as u8,
flags: headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::from_static(&[0, 0, 0, 0, 0]);
let err = PushPromiseFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn push_promise_encode_rejects_invalid_stream_ids_without_partial_write() {
let cases = [
PushPromiseFrame {
stream_id: 0,
promised_stream_id: 2,
header_block: Bytes::from_static(b"hdr"),
end_headers: true,
},
PushPromiseFrame {
stream_id: 1,
promised_stream_id: 0,
header_block: Bytes::from_static(b"hdr"),
end_headers: true,
},
PushPromiseFrame {
stream_id: 1,
promised_stream_id: 0x8000_0000,
header_block: Bytes::from_static(b"hdr"),
end_headers: true,
},
];
for frame in cases {
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("invalid PUSH_PROMISE ids must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, None);
assert!(
buf.is_empty(),
"invalid PUSH_PROMISE must not partially encode"
);
}
}
#[test]
fn test_push_promise_padded_empty_payload_rejected() {
let header = FrameHeader {
length: 0,
frame_type: FrameType::PushPromise as u8,
flags: headers_flags::PADDED | headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::new();
let err = PushPromiseFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_push_promise_padding_exceeds_length() {
let header = FrameHeader {
length: 5,
frame_type: FrameType::PushPromise as u8,
flags: headers_flags::PADDED | headers_flags::END_HEADERS,
stream_id: 1,
};
let payload = Bytes::from_static(&[1, 0, 0, 0, 1]);
let err = PushPromiseFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_ping_non_zero_stream_id_rejected() {
let header = FrameHeader {
length: 8,
frame_type: FrameType::Ping as u8,
flags: 0,
stream_id: 1,
};
let payload = Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0]);
let err = PingFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_ping_wrong_size() {
let header = FrameHeader {
length: 7,
frame_type: FrameType::Ping as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0]);
let err = PingFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
}
#[test]
fn test_ping_rejects_every_non_eight_byte_payload_size() {
for payload_len in [0usize, 1, 7, 9, 16] {
let header = FrameHeader {
length: u32::try_from(payload_len).expect("test payload length fits u32"),
frame_type: FrameType::Ping as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from(vec![0; payload_len]);
let err = PingFrame::parse(&header, &payload).unwrap_err();
assert_eq!(
err.code,
ErrorCode::FrameSizeError,
"PING payload length {payload_len} must be rejected before connection handling"
);
}
}
#[test]
fn test_goaway_non_zero_stream_id_rejected() {
let header = FrameHeader {
length: 8,
frame_type: FrameType::GoAway as u8,
flags: 0,
stream_id: 1,
};
let payload = Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0, 0]);
let err = GoAwayFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn test_goaway_too_short() {
let header = FrameHeader {
length: 7,
frame_type: FrameType::GoAway as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 0, 0, 0, 0, 0, 0]);
let err = GoAwayFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
}
#[test]
fn test_goaway_with_debug_data() {
let mut original = GoAwayFrame::new(100, ErrorCode::EnhanceYourCalm);
original.debug_data = Bytes::from_static(b"too many requests");
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = GoAwayFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.last_stream_id, 100);
assert_eq!(parsed.error_code, ErrorCode::EnhanceYourCalm);
assert_eq!(&parsed.debug_data[..], b"too many requests");
}
#[test]
fn goaway_parse_ignores_reserved_last_stream_id_bit() {
let header = FrameHeader {
length: 8,
frame_type: FrameType::GoAway as u8,
flags: 0,
stream_id: 0,
};
for payload in [
Bytes::from_static(&[0x00, 0x00, 0x00, 0x09, 0, 0, 0, 0]),
Bytes::from_static(&[0x80, 0x00, 0x00, 0x09, 0, 0, 0, 0]),
] {
let parsed = GoAwayFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.last_stream_id, 9);
assert_eq!(parsed.error_code, ErrorCode::NoError);
assert!(parsed.debug_data.is_empty());
}
}
#[test]
fn test_window_update_wrong_size() {
let header = FrameHeader {
length: 3,
frame_type: FrameType::WindowUpdate as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 0, 0]);
let err = WindowUpdateFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
}
#[test]
fn test_window_update_zero_increment_rejected() {
let header = FrameHeader {
length: 4,
frame_type: FrameType::WindowUpdate as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(&[0, 0, 0, 0]);
let err = WindowUpdateFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, None); }
#[test]
fn test_window_update_zero_increment_stream_level_is_stream_error() {
let header = FrameHeader {
length: 4,
frame_type: FrameType::WindowUpdate as u8,
flags: 0,
stream_id: 3,
};
let payload = Bytes::from_static(&[0, 0, 0, 0]);
let err = WindowUpdateFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, Some(3)); }
#[test]
fn window_update_encode_rejects_invalid_increment_bounds() {
let cases = [
(WindowUpdateFrame::new(0, 0), None),
(WindowUpdateFrame::new(3, 0), Some(3)),
(
WindowUpdateFrame::new(0, MAX_WINDOW_UPDATE_INCREMENT.saturating_add(1)),
None,
),
(
WindowUpdateFrame::new(3, MAX_WINDOW_UPDATE_INCREMENT.saturating_add(1)),
Some(3),
),
];
for (frame, expected_stream_id) in cases {
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("invalid increment must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, expected_stream_id);
assert!(buf.is_empty(), "invalid frame must not partially encode");
}
}
#[test]
fn test_continuation_stream_id_zero_rejected() {
let header = FrameHeader {
length: 5,
frame_type: FrameType::Continuation as u8,
flags: continuation_flags::END_HEADERS,
stream_id: 0,
};
let payload = Bytes::from_static(b"hdr");
let err = ContinuationFrame::parse(&header, payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn continuation_encode_rejects_stream_id_zero_without_partial_write() {
let frame = ContinuationFrame {
stream_id: 0,
header_block: Bytes::from_static(b"hdr"),
end_headers: true,
};
let mut buf = BytesMut::new();
let err = frame
.encode(&mut buf)
.expect_err("invalid CONTINUATION stream id must not encode");
assert_eq!(err.code, ErrorCode::ProtocolError);
assert_eq!(err.stream_id, None);
assert!(
buf.is_empty(),
"invalid CONTINUATION must not partially encode"
);
}
#[test]
fn test_unknown_frame_type_ignored() {
let header = FrameHeader {
length: 3,
frame_type: 0xFF,
flags: 0,
stream_id: 7,
};
let payload = Bytes::from_static(b"abc");
let frame = parse_frame(&header, payload.clone()).unwrap();
match frame {
Frame::Unknown {
frame_type,
stream_id,
payload: p,
} => {
assert_eq!(frame_type, 0xFF);
assert_eq!(stream_id, 7);
assert_eq!(p, payload);
}
_ => panic!("expected Frame::Unknown"),
}
}
#[test]
fn test_max_frame_size_constants() {
assert_eq!(DEFAULT_MAX_FRAME_SIZE, 16_384);
assert_eq!(MAX_FRAME_SIZE, 16_777_215);
assert_eq!(MIN_MAX_FRAME_SIZE, 16_384);
const {
assert!(DEFAULT_MAX_FRAME_SIZE >= MIN_MAX_FRAME_SIZE);
assert!(DEFAULT_MAX_FRAME_SIZE <= MAX_FRAME_SIZE);
}
}
#[test]
fn test_frame_header_length_max() {
let header = FrameHeader {
length: MAX_FRAME_SIZE,
frame_type: FrameType::Data as u8,
flags: 0,
stream_id: 1,
};
let mut buf = BytesMut::new();
header.write(&mut buf);
let parsed = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(parsed.length, MAX_FRAME_SIZE);
}
#[test]
fn frame_header_parse_masks_inbound_reserved_stream_bit() {
for stream_id in [0_u32, 1, 3, 0x7fff_ffff] {
let raw_without_reserved = stream_id.to_be_bytes();
let mut raw_with_reserved = raw_without_reserved;
raw_with_reserved[0] |= 0x80;
for raw_stream_id in [raw_without_reserved, raw_with_reserved] {
let mut buf = BytesMut::from(
&[
0,
0,
5,
FrameType::Headers as u8,
headers_flags::END_HEADERS,
raw_stream_id[0],
raw_stream_id[1],
raw_stream_id[2],
raw_stream_id[3],
b'h',
b'e',
b'l',
b'l',
b'o',
][..],
);
let parsed = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(parsed.length, 5);
assert_eq!(parsed.frame_type, FrameType::Headers as u8);
assert_eq!(parsed.flags, headers_flags::END_HEADERS);
assert_eq!(parsed.stream_id, stream_id);
assert_eq!(&buf[..], b"hello", "header parse must leave payload bytes");
}
}
}
#[test]
fn frame_header_parse_consumes_exactly_one_header() {
let first = FrameHeader {
length: 0,
frame_type: FrameType::Settings as u8,
flags: settings_flags::ACK,
stream_id: 0,
};
let second = FrameHeader {
length: 8,
frame_type: FrameType::Ping as u8,
flags: 0,
stream_id: 0,
};
let mut buf = BytesMut::new();
first.write(&mut buf);
second.write(&mut buf);
buf.extend_from_slice(b"12345678");
let parsed_first = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(parsed_first, first);
assert_eq!(buf.len(), FRAME_HEADER_SIZE + 8);
let parsed_second = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(parsed_second, second);
assert_eq!(&buf[..], b"12345678");
}
#[test]
fn test_try_frame_length_rejects_above_max() {
let err = try_frame_length((MAX_FRAME_SIZE as usize).saturating_add(1)).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
assert!(err.to_string().contains("payload length"));
}
#[test]
fn test_data_frame_encode_rejects_payload_above_max() {
let original = DataFrame::new(
1,
Bytes::from(vec![0_u8; (MAX_FRAME_SIZE as usize).saturating_add(1)]),
true,
);
let mut buf = BytesMut::new();
let err = original.encode(&mut buf).unwrap_err();
assert_eq!(err.code, ErrorCode::FrameSizeError);
assert!(buf.is_empty(), "oversized frame must not partially encode");
}
#[test]
fn test_stream_id_31_bits() {
let header = FrameHeader {
length: 0,
frame_type: FrameType::Data as u8,
flags: 0,
stream_id: 0x7FFF_FFFF, };
let mut buf = BytesMut::new();
header.write(&mut buf);
let parsed = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(parsed.stream_id, 0x7FFF_FFFF);
}
#[test]
fn test_stream_id_reserved_bit_masked() {
let header = FrameHeader {
length: 0,
frame_type: FrameType::Data as u8,
flags: 0,
stream_id: 0xFFFF_FFFF,
};
let mut buf = BytesMut::new();
header.write(&mut buf);
let parsed = FrameHeader::parse(&mut buf).unwrap();
assert_eq!(parsed.stream_id, 0x7FFF_FFFF);
}
#[test]
fn test_frame_type_all_variants() {
assert_eq!(FrameType::from_u8(0x0), Some(FrameType::Data));
assert_eq!(FrameType::from_u8(0x1), Some(FrameType::Headers));
assert_eq!(FrameType::from_u8(0x2), Some(FrameType::Priority));
assert_eq!(FrameType::from_u8(0x3), Some(FrameType::RstStream));
assert_eq!(FrameType::from_u8(0x4), Some(FrameType::Settings));
assert_eq!(FrameType::from_u8(0x5), Some(FrameType::PushPromise));
assert_eq!(FrameType::from_u8(0x6), Some(FrameType::Ping));
assert_eq!(FrameType::from_u8(0x7), Some(FrameType::GoAway));
assert_eq!(FrameType::from_u8(0x8), Some(FrameType::WindowUpdate));
assert_eq!(FrameType::from_u8(0x9), Some(FrameType::Continuation));
assert_eq!(FrameType::from_u8(0xA), None);
assert_eq!(FrameType::from_u8(0xFF), None);
}
#[test]
fn test_setting_all_variants() {
assert_eq!(
Setting::from_id_value(0x1, 4096),
Some(Setting::HeaderTableSize(4096))
);
assert_eq!(
Setting::from_id_value(0x2, 1),
Some(Setting::EnablePush(true))
);
assert_eq!(
Setting::from_id_value(0x2, 0),
Some(Setting::EnablePush(false))
);
assert_eq!(
Setting::from_id_value(0x3, 100),
Some(Setting::MaxConcurrentStreams(100))
);
assert_eq!(
Setting::from_id_value(0x4, 65535),
Some(Setting::InitialWindowSize(65535))
);
assert_eq!(
Setting::from_id_value(0x5, 16384),
Some(Setting::MaxFrameSize(16384))
);
assert_eq!(
Setting::from_id_value(0x6, 8192),
Some(Setting::MaxHeaderListSize(8192))
);
assert_eq!(Setting::from_id_value(0x7, 123), None);
assert_eq!(Setting::from_id_value(0xFF, 456), None);
}
#[test]
fn test_settings_frame_rejects_invalid_enable_push_value() {
let header = FrameHeader {
length: 6,
frame_type: FrameType::Settings as u8,
flags: 0,
stream_id: 0,
};
let payload = Bytes::from_static(&[0x00, 0x02, 0x00, 0x00, 0x00, 0x02]);
let err = SettingsFrame::parse(&header, &payload).unwrap_err();
assert_eq!(err.code, ErrorCode::ProtocolError);
}
#[test]
fn settings_frame_rejects_invalid_wire_bounds() {
let header = FrameHeader {
length: 6,
frame_type: FrameType::Settings as u8,
flags: 0,
stream_id: 0,
};
let cases = [
(0x04_u16, 0x8000_0000_u32, ErrorCode::FlowControlError),
(
0x05,
MIN_MAX_FRAME_SIZE.saturating_sub(1),
ErrorCode::ProtocolError,
),
(
0x05,
MAX_FRAME_SIZE.saturating_add(1),
ErrorCode::ProtocolError,
),
];
for (id, value, expected_code) in cases {
let mut payload = BytesMut::with_capacity(6);
payload.put_u16(id);
payload.put_u32(value);
let err = SettingsFrame::parse(&header, &payload.freeze()).unwrap_err();
assert_eq!(err.code, expected_code, "id={id:#x} value={value}");
}
}
#[test]
fn settings_frame_encode_rejects_invalid_wire_bounds_without_partial_write() {
let cases = [
(
Setting::InitialWindowSize(MAX_INITIAL_WINDOW_SIZE.saturating_add(1)),
ErrorCode::FlowControlError,
),
(
Setting::MaxFrameSize(MIN_MAX_FRAME_SIZE.saturating_sub(1)),
ErrorCode::ProtocolError,
),
(
Setting::MaxFrameSize(MAX_FRAME_SIZE.saturating_add(1)),
ErrorCode::ProtocolError,
),
];
for (setting, expected_code) in cases {
let mut buf = BytesMut::new();
let err = SettingsFrame::new(vec![setting])
.encode(&mut buf)
.expect_err("invalid setting must not encode");
assert_eq!(err.code, expected_code, "setting={setting:?}");
assert!(buf.is_empty(), "invalid setting must not partially encode");
}
}
#[test]
fn settings_frame_accepts_wire_bounds() {
let header = FrameHeader {
length: 12,
frame_type: FrameType::Settings as u8,
flags: 0,
stream_id: 0,
};
let mut payload = BytesMut::with_capacity(12);
payload.put_u16(0x04);
payload.put_u32(MAX_INITIAL_WINDOW_SIZE);
payload.put_u16(0x05);
payload.put_u32(MAX_FRAME_SIZE);
let parsed = SettingsFrame::parse(&header, &payload.freeze()).unwrap();
assert_eq!(
parsed.settings,
vec![
Setting::InitialWindowSize(MAX_INITIAL_WINDOW_SIZE),
Setting::MaxFrameSize(MAX_FRAME_SIZE)
]
);
}
#[test]
fn test_setting_id_and_value() {
let settings = vec![
Setting::HeaderTableSize(4096),
Setting::EnablePush(true),
Setting::MaxConcurrentStreams(100),
Setting::InitialWindowSize(65535),
Setting::MaxFrameSize(16384),
Setting::MaxHeaderListSize(8192),
];
for setting in settings {
assert_eq!(
Setting::from_id_value(setting.id(), setting.value()),
Some(setting)
);
}
}
#[test]
fn test_window_update_max_increment() {
let original = WindowUpdateFrame::new(0, 0x7FFF_FFFF);
let mut buf = BytesMut::new();
original.encode(&mut buf).expect("encode");
let header = FrameHeader::parse(&mut buf).unwrap();
let payload = buf.split_to(header.length as usize).freeze();
let parsed = WindowUpdateFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.increment, 0x7FFF_FFFF);
}
#[test]
fn window_update_parse_ignores_reserved_increment_bit() {
let header = FrameHeader {
length: 4,
frame_type: FrameType::WindowUpdate as u8,
flags: 0,
stream_id: 7,
};
for payload in [
Bytes::from_static(&[0x00, 0x00, 0x04, 0x00]),
Bytes::from_static(&[0x80, 0x00, 0x04, 0x00]),
] {
let parsed = WindowUpdateFrame::parse(&header, &payload).unwrap();
assert_eq!(parsed.stream_id, 7);
assert_eq!(parsed.increment, 1024);
}
}
#[test]
fn test_error_code_all_variants() {
let codes = [
(0x0, ErrorCode::NoError),
(0x1, ErrorCode::ProtocolError),
(0x3, ErrorCode::FlowControlError),
(0x4, ErrorCode::SettingsTimeout),
(0x5, ErrorCode::StreamClosed),
(0x6, ErrorCode::FrameSizeError),
(0x7, ErrorCode::RefusedStream),
(0x8, ErrorCode::Cancel),
(0x9, ErrorCode::CompressionError),
(0xa, ErrorCode::ConnectError),
(0xb, ErrorCode::EnhanceYourCalm),
(0xc, ErrorCode::InadequateSecurity),
(0xd, ErrorCode::Http11Required),
];
for (value, expected) in codes {
let code = ErrorCode::from_u32(value);
assert_eq!(code, expected);
assert_eq!(u32::from(code), value);
}
assert_eq!(ErrorCode::from_u32(0xFFFF), ErrorCode::InternalError);
}
#[test]
fn test_partial_header_parse_insufficient_bytes() {
let mut buf = BytesMut::from(&[0, 0, 5, 0, 0, 0, 0][..]);
let result = FrameHeader::parse(&mut buf);
assert!(result.is_err());
}
#[test]
fn frame_type_debug_clone_copy_eq_hash() {
use std::collections::HashSet;
let ft = FrameType::Headers;
let dbg = format!("{ft:?}");
assert!(dbg.contains("Headers"));
let ft2 = ft;
assert_eq!(ft, ft2);
let ft3 = ft;
assert_eq!(ft, ft3);
assert_ne!(FrameType::Data, FrameType::Settings);
let mut set = HashSet::new();
set.insert(FrameType::Data);
set.insert(FrameType::Ping);
assert_eq!(set.len(), 2);
}
#[test]
fn frame_header_debug_clone_copy_eq() {
let fh = FrameHeader {
length: 42,
frame_type: 0x1,
flags: 0x4,
stream_id: 1,
};
let dbg = format!("{fh:?}");
assert!(dbg.contains("FrameHeader"));
let fh2 = fh;
assert_eq!(fh, fh2);
let fh3 = fh;
assert_eq!(fh, fh3);
}
#[test]
fn priority_spec_debug_clone_copy_eq() {
let ps = PrioritySpec {
exclusive: false,
dependency: 0,
weight: 16,
};
let dbg = format!("{ps:?}");
assert!(dbg.contains("PrioritySpec"));
let ps2 = ps;
assert_eq!(ps, ps2);
let ps3 = ps;
assert_eq!(ps, ps3);
}
#[test]
fn settings_frame_debug_clone_default() {
let sf = SettingsFrame::default();
let dbg = format!("{sf:?}");
assert!(dbg.contains("SettingsFrame"));
let sf2 = sf;
assert_eq!(sf2.settings.len(), 0);
assert!(!sf2.ack);
}
#[test]
fn setting_debug_clone_copy_eq() {
let s = Setting::HeaderTableSize(4096);
let dbg = format!("{s:?}");
assert!(dbg.contains("HeaderTableSize"));
assert!(dbg.contains("4096"));
let s2 = s;
assert_eq!(s, s2);
let s3 = s;
assert_eq!(s, s3);
assert_ne!(Setting::EnablePush(true), Setting::EnablePush(false));
}
#[derive(Debug, Clone)]
struct Rfc7540ContinuationTestCase {
id: &'static str,
description: &'static str,
requirement_level: &'static str, test_fn: fn() -> Result<(), H2Error>,
}
fn get_rfc7540_continuation_test_cases() -> Vec<Rfc7540ContinuationTestCase> {
vec![
Rfc7540ContinuationTestCase {
id: "RFC7540-6.10.1",
description: "CONTINUATION frames MUST be sent only when a stream is in open or half-closed (remote) state",
requirement_level: "MUST",
test_fn: test_continuation_valid_stream_state,
},
Rfc7540ContinuationTestCase {
id: "RFC7540-6.10.2",
description: "CONTINUATION frames MUST be associated with a stream (stream ID > 0)",
requirement_level: "MUST",
test_fn: test_continuation_nonzero_stream_id,
},
Rfc7540ContinuationTestCase {
id: "RFC7540-6.10.3",
description: "If HEADERS frame does not have END_HEADERS, CONTINUATION frames MUST follow",
requirement_level: "MUST",
test_fn: test_continuation_follows_headers_without_end_headers,
},
Rfc7540ContinuationTestCase {
id: "RFC7540-6.10.4",
description: "Receiver MUST treat any other frame type between HEADERS and CONTINUATION as connection error",
requirement_level: "MUST",
test_fn: test_continuation_intervening_frames_rejected,
},
Rfc7540ContinuationTestCase {
id: "RFC7540-6.10.5",
description: "Receiver MUST treat CONTINUATION on different stream as connection error",
requirement_level: "MUST",
test_fn: test_continuation_different_stream_rejected,
},
]
}
fn test_continuation_valid_stream_state() -> Result<(), H2Error> {
let header = FrameHeader {
length: 0,
frame_type: FrameType::Continuation as u8,
flags: continuation_flags::END_HEADERS,
stream_id: 1, };
let payload = Bytes::new();
let frame = ContinuationFrame::parse(&header, payload)?;
assert_eq!(frame.stream_id, 1);
assert!(frame.end_headers);
Ok(())
}
fn test_continuation_nonzero_stream_id() -> Result<(), H2Error> {
let header = FrameHeader {
length: 0,
frame_type: FrameType::Continuation as u8,
flags: 0,
stream_id: 0, };
let payload = Bytes::new();
match ContinuationFrame::parse(&header, payload) {
Err(err) if err.code == crate::http::h2::error::ErrorCode::ProtocolError => Ok(()),
Err(other) => Err(other),
Ok(_) => Err(H2Error::protocol(
"CONTINUATION with stream ID 0 should be rejected",
)),
}
}
fn test_continuation_follows_headers_without_end_headers() -> Result<(), H2Error> {
let large_header_block = vec![0u8; 32_000];
let headers_frame = HeadersFrame {
stream_id: 1,
header_block: Bytes::from(large_header_block[..16_000].to_vec()),
end_stream: false,
end_headers: false, priority: None,
};
let continuation_frame = ContinuationFrame {
stream_id: 1,
header_block: Bytes::from(large_header_block[16_000..].to_vec()),
end_headers: true, };
let mut headers_buf = BytesMut::new();
headers_frame.encode(&mut headers_buf)?;
let mut continuation_buf = BytesMut::new();
continuation_frame.encode(&mut continuation_buf)?;
let total_size = headers_frame.header_block.len() + continuation_frame.header_block.len();
assert_eq!(total_size, 32_000);
Ok(())
}
fn test_continuation_intervening_frames_rejected() -> Result<(), H2Error> {
let headers_without_end = FrameHeader {
length: 100,
frame_type: FrameType::Headers as u8,
flags: 0, stream_id: 1,
};
let intervening_data = FrameHeader {
length: 10,
frame_type: FrameType::Data as u8,
flags: 0,
stream_id: 1, };
let continuation_after = FrameHeader {
length: 50,
frame_type: FrameType::Continuation as u8,
flags: continuation_flags::END_HEADERS,
stream_id: 1,
};
assert!(!headers_without_end.has_flag(headers_flags::END_HEADERS));
assert_eq!(intervening_data.frame_type, FrameType::Data as u8);
let cont_frame = ContinuationFrame::parse(&continuation_after, Bytes::new())?;
assert!(cont_frame.end_headers);
Ok(())
}
fn test_continuation_different_stream_rejected() -> Result<(), H2Error> {
let headers_stream_1 = FrameHeader {
length: 100,
frame_type: FrameType::Headers as u8,
flags: 0, stream_id: 1,
};
let continuation_stream_3 = FrameHeader {
length: 50,
frame_type: FrameType::Continuation as u8,
flags: continuation_flags::END_HEADERS,
stream_id: 3, };
assert!(!headers_stream_1.has_flag(headers_flags::END_HEADERS));
let cont_frame = ContinuationFrame::parse(&continuation_stream_3, Bytes::new())?;
assert_eq!(cont_frame.stream_id, 3);
assert!(cont_frame.end_headers);
assert_ne!(headers_stream_1.stream_id, continuation_stream_3.stream_id);
Ok(())
}
#[test]
fn continuation_frame_large_header_block_golden_test() {
let large_headers = vec![
("x-very-long-header-name-1", "x".repeat(5000)),
("x-very-long-header-name-2", "y".repeat(5000)),
("x-very-long-header-name-3", "z".repeat(5000)),
("authorization", "Bearer ".to_string() + &"t".repeat(2000)),
(
"user-agent",
"Mozilla/5.0 ".to_string() + &"(details; ".repeat(500) + ")",
),
];
let mut total_size = 0;
for (name, value) in &large_headers {
total_size += name.len() + value.len() + 32; }
assert!(total_size > DEFAULT_MAX_FRAME_SIZE as usize);
let chunk_size = DEFAULT_MAX_FRAME_SIZE as usize;
let chunks_needed = (total_size + chunk_size - 1) / chunk_size;
assert!(
chunks_needed >= 2,
"Large header set should require at least 2 frames"
);
let headers_frame = HeadersFrame {
stream_id: 5,
header_block: Bytes::from(vec![0u8; chunk_size]),
end_stream: false,
end_headers: false, priority: None,
};
let mut continuation_frames = Vec::new();
for i in 1..chunks_needed {
let is_last = i == chunks_needed - 1;
let remaining_size = if is_last {
total_size - (i * chunk_size)
} else {
chunk_size
};
continuation_frames.push(ContinuationFrame {
stream_id: 5,
header_block: Bytes::from(vec![0u8; remaining_size]),
end_headers: is_last, });
}
assert!(!headers_frame.end_headers);
assert_eq!(continuation_frames.len(), chunks_needed - 1);
assert!(continuation_frames.last().unwrap().end_headers);
for frame in &continuation_frames {
assert_eq!(frame.stream_id, headers_frame.stream_id);
}
eprintln!(
"✓ Large header block CONTINUATION scenario: {} chunks for {} bytes",
chunks_needed, total_size
);
}
#[test]
fn rfc7540_continuation_conformance() {
let test_cases = get_rfc7540_continuation_test_cases();
let mut pass_count = 0;
let mut fail_count = 0;
for test_case in test_cases {
match (test_case.test_fn)() {
Ok(()) => {
pass_count += 1;
eprintln!(
"✓ {} [{}]: {}",
test_case.id, test_case.requirement_level, test_case.description
);
}
Err(e) => {
fail_count += 1;
eprintln!(
"✗ {} [{}]: {} - Error: {}",
test_case.id, test_case.requirement_level, test_case.description, e
);
}
}
}
eprintln!(
"RFC 7540 CONTINUATION Conformance: {}/{} tests passed",
pass_count,
pass_count + fail_count
);
assert_eq!(
fail_count, 0,
"{} CONTINUATION conformance tests failed",
fail_count
);
}
#[test]
fn continuation_frame_encoding_roundtrip() {
let original = ContinuationFrame {
stream_id: 42,
header_block: Bytes::from("test header block data"),
end_headers: true,
};
let mut buffer = BytesMut::new();
original.encode(&mut buffer).expect("encode should succeed");
let header = FrameHeader::parse(&mut buffer).expect("header parse should succeed");
assert_eq!(header.frame_type, FrameType::Continuation as u8);
assert_eq!(header.stream_id, 42);
assert!(header.has_flag(continuation_flags::END_HEADERS));
let payload = buffer.split_to(header.length as usize).freeze();
let parsed = ContinuationFrame::parse(&header, payload).expect("parse should succeed");
assert_eq!(parsed.stream_id, original.stream_id);
assert_eq!(parsed.header_block, original.header_block);
assert_eq!(parsed.end_headers, original.end_headers);
eprintln!("✓ CONTINUATION frame round-trip test passed");
}
}