use crate::error::FrameDecodeError;
use crate::settings::Setting;
use crate::varint::{self, VarInt};
use super::{
DataPayload, Frame, FrameType, GoawayPayload, HeadersPayload, SettingsPayload, UnknownFrame,
};
/// Decoded header of an HTTP/3 frame: the frame type and payload length
/// varints, plus how many bytes those two varints occupied in the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FrameHeader {
    /// Frame type as read from the wire.
    frame_type: VarInt,
    /// Declared payload length in bytes; the payload itself is not part of
    /// the header.
    payload_len: VarInt,
    /// Bytes consumed by the two header varints. This can be larger than their
    /// minimal encoded size because non-minimal varint encodings are accepted.
    header_len: usize,
}
impl FrameHeader {
    /// Assembles a header from values the decoder has already checked.
    ///
    /// `header_len` is the number of bytes the type and length varints
    /// actually took up. It is at least their minimal encoded size (larger
    /// when the peer used a non-minimal encoding) and at most 16 bytes. These
    /// invariants are only asserted in debug builds.
    pub(crate) const fn from_validated_parts(
        frame_type: VarInt,
        payload_len: VarInt,
        header_len: usize,
    ) -> Self {
        debug_assert!(header_len >= frame_type.encoded_len() + payload_len.encoded_len());
        debug_assert!(header_len <= 16);
        Self {
            frame_type,
            payload_len,
            header_len,
        }
    }

    /// Returns the frame type varint.
    pub const fn frame_type(&self) -> VarInt {
        self.frame_type
    }

    /// Returns the declared payload length in bytes.
    pub const fn payload_len(&self) -> VarInt {
        self.payload_len
    }

    /// Returns the number of bytes occupied by the encoded header.
    pub const fn header_len(&self) -> usize {
        self.header_len
    }

    /// Returns the size of the whole frame: header bytes plus payload bytes.
    ///
    /// Returns `None` when the payload length does not fit in `usize` or the
    /// sum overflows.
    pub fn total_len(&self) -> Option<usize> {
        usize::try_from(self.payload_len.get())
            .ok()
            .and_then(|payload| self.header_len.checked_add(payload))
    }
}
/// Decodes the frame type and payload length varints at the start of `buf`.
///
/// Only the header is decoded; the payload bytes need not be present.
///
/// # Errors
///
/// - [`FrameDecodeError::Http2Frame`] if the frame type is one that
///   `FrameType::is_http2_only` flags. This is reported as soon as the type
///   varint is available, even when the length varint has not arrived yet.
/// - [`FrameDecodeError::BufferTooShort`] if either varint fails to decode
///   from the available bytes, including when `buf` is empty.
pub fn decode_frame_header(buf: &[u8]) -> Result<FrameHeader, FrameDecodeError> {
    let (frame_type, type_len) =
        varint::decode(buf).map_err(|_| FrameDecodeError::BufferTooShort)?;
    // The type alone decides whether this is an HTTP/2 leftover, so reject it
    // before looking at the length. Otherwise a partial header would report
    // `BufferTooShort` and postpone a protocol error the peer already made.
    if FrameType::is_http2_only(frame_type.get()) {
        return Err(FrameDecodeError::Http2Frame(frame_type));
    }
    // `get` instead of indexing keeps this path panic-free even if the varint
    // decoder ever reported consuming more bytes than it was given.
    let rest = buf
        .get(type_len..)
        .ok_or(FrameDecodeError::BufferTooShort)?;
    let (payload_len, len_len) =
        varint::decode(rest).map_err(|_| FrameDecodeError::BufferTooShort)?;
    Ok(FrameHeader::from_validated_parts(
        frame_type,
        payload_len,
        type_len + len_len,
    ))
}
/// Decodes one complete frame from the front of `buf`.
///
/// On success, returns the frame and the number of bytes it occupied (header
/// plus payload). Any bytes after that are left untouched for the caller.
///
/// # Errors
///
/// - Any error from [`decode_frame_header`].
/// - [`FrameDecodeError::InvalidLength`] if header length plus payload length
///   cannot be represented as a `usize`.
/// - [`FrameDecodeError::BufferTooShort`] if `buf` does not yet hold the
///   whole frame.
/// - [`FrameDecodeError::ServerPushNotSupported`] for CANCEL_PUSH and
///   PUSH_PROMISE frames.
/// - Payload errors from the SETTINGS, GOAWAY and MAX_PUSH_ID decoders.
pub fn decode_frame(buf: &[u8]) -> Result<(Frame, usize), FrameDecodeError> {
    let header = decode_frame_header(buf)?;
    let frame_len = match header.total_len() {
        Some(len) => len,
        None => return Err(FrameDecodeError::InvalidLength),
    };
    let frame_bytes = match buf.get(..frame_len) {
        Some(bytes) => bytes,
        None => return Err(FrameDecodeError::BufferTooShort),
    };
    let payload = &frame_bytes[header.header_len()..];
    let frame_type = header.frame_type();

    let frame = match FrameType::from_type(frame_type.get()) {
        Some(FrameType::Data) => Frame::Data(DataPayload::new(payload.to_vec())),
        Some(FrameType::Headers) => Frame::Headers(HeadersPayload::new(payload.to_vec())),
        Some(FrameType::Settings) => decode_settings_frame(payload)?,
        Some(FrameType::Goaway) => decode_goaway_frame(payload)?,
        Some(FrameType::MaxPushId) => decode_max_push_id_frame(payload)?,
        Some(FrameType::CancelPush | FrameType::PushPromise) => {
            return Err(FrameDecodeError::ServerPushNotSupported(frame_type));
        }
        // `decode_frame_header` has already rejected HTTP/2-only types, so this
        // arm only sees types that are neither known HTTP/3 types nor HTTP/2
        // leftovers; they are handed back to the caller as `Frame::Unknown`.
        None => Frame::Unknown(
            UnknownFrame::new(frame_type, payload.to_vec())
                .expect("None arm receives only unknown non-HTTP/2 frame types"),
        ),
    };
    Ok((frame, frame_len))
}
fn decode_settings_frame(payload: &[u8]) -> Result<Frame, FrameDecodeError> {
let mut settings = SettingsPayload::new();
let mut offset = 0;
while offset < payload.len() {
let (id, id_len) =
varint::decode(&payload[offset..]).map_err(|_| FrameDecodeError::InvalidLength)?;
offset += id_len;
let (value, value_len) =
varint::decode(&payload[offset..]).map_err(|_| FrameDecodeError::InvalidLength)?;
offset += value_len;
let setting = Setting::from_wire(id, value)?;
settings.add(setting)?;
}
Ok(Frame::Settings(settings))
}
fn decode_goaway_frame(payload: &[u8]) -> Result<Frame, FrameDecodeError> {
let (id, consumed) = varint::decode(payload).map_err(|_| FrameDecodeError::InvalidLength)?;
if consumed != payload.len() {
return Err(FrameDecodeError::InvalidLength);
}
Ok(Frame::Goaway(GoawayPayload::new(id)))
}
fn decode_max_push_id_frame(payload: &[u8]) -> Result<Frame, FrameDecodeError> {
let (id, consumed) = varint::decode(payload).map_err(|_| FrameDecodeError::InvalidLength)?;
if consumed != payload.len() {
return Err(FrameDecodeError::InvalidLength);
}
Ok(Frame::MaxPushId(id))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::settings::SettingError;

    #[test]
    fn test_decode_frame_header() {
        let buf = [0x00, 0x05];
        let header = decode_frame_header(&buf).unwrap();
        assert_eq!(header.frame_type().get(), 0);
        assert_eq!(header.payload_len().get(), 5);
        assert_eq!(header.header_len(), 2);

        let buf = [0x04, 0x0a];
        let header = decode_frame_header(&buf).unwrap();
        assert_eq!(header.frame_type().get(), 4);
        assert_eq!(header.payload_len().get(), 10);
    }

    #[test]
    fn test_decode_frame_header_http2_frame() {
        let buf = [0x02, 0x05];
        let result = decode_frame_header(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::Http2Frame(VarInt::from_static(0x02)))
        );

        let buf = [0x06, 0x08];
        let result = decode_frame_header(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::Http2Frame(VarInt::from_static(0x06)))
        );
    }

    #[test]
    fn test_decode_settings_frame_http2_setting() {
        let buf = [0x04, 0x02, 0x02, 0x01];
        let result = decode_frame(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::InvalidSetting(
                SettingError::Http2OnlyId {
                    id: VarInt::new(0x02).unwrap()
                }
            ))
        );
    }

    #[test]
    fn test_decode_frame_buffer_too_short() {
        let buf = [0x00];
        assert_eq!(decode_frame(&buf), Err(FrameDecodeError::BufferTooShort));

        // Header announces 5 payload bytes but only 2 are present.
        let buf = [0x00, 0x05, 0x01, 0x02];
        assert_eq!(decode_frame(&buf), Err(FrameDecodeError::BufferTooShort));
    }

    #[test]
    fn test_frame_header_total_len() {
        let header =
            FrameHeader::from_validated_parts(VarInt::from_static(0), VarInt::from_static(50), 2);
        assert_eq!(header.header_len(), 2);
        assert_eq!(header.total_len(), Some(52));
    }

    #[test]
    fn test_decode_frame_handles_non_minimal_varint_encoding() {
        // Frame type 0 encoded in two bytes instead of one.
        let buf = [0x40, 0x00, 0x03, 0xaa, 0xbb, 0xcc];
        let header = decode_frame_header(&buf).unwrap();
        assert_eq!(header.frame_type().get(), 0);
        assert_eq!(header.payload_len().get(), 3);
        assert_eq!(header.header_len(), 3);
        assert_eq!(header.total_len(), Some(6));

        let (frame, consumed) = decode_frame(&buf).unwrap();
        assert_eq!(consumed, 6);
        match frame {
            Frame::Data(p) => assert_eq!(p.data(), &[0xaa, 0xbb, 0xcc][..]),
            other => panic!("expected DATA, got {other:?}"),
        }
    }

    #[test]
    fn test_decode_settings_frame_duplicate_id() {
        let buf = [0x04, 0x04, 0x01, 0x04, 0x01, 0x08];
        let result = decode_frame(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::InvalidSetting(
                SettingError::DuplicateId {
                    id: VarInt::new(0x01).unwrap()
                }
            ))
        );
    }

    #[test]
    fn test_decode_settings_frame_reserved_id() {
        let buf = [0x04, 0x02, 0x00, 0x00];
        let result = decode_frame(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::InvalidSetting(SettingError::ReservedId {
                id: VarInt::ZERO
            }))
        );
    }

    #[test]
    fn test_decode_settings_frame_invalid_boolean_ecp() {
        let buf = [0x04, 0x02, 0x08, 0x02];
        let result = decode_frame(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::InvalidSetting(
                SettingError::InvalidBooleanValue {
                    id: VarInt::new(0x08).unwrap(),
                    value: VarInt::new(0x02).unwrap(),
                }
            ))
        );
    }

    #[test]
    fn test_decode_settings_frame_invalid_boolean_h3_datagram() {
        let buf = [0x04, 0x02, 0x33, 0x05];
        let result = decode_frame(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::InvalidSetting(
                SettingError::InvalidBooleanValue {
                    id: VarInt::new(0x33).unwrap(),
                    value: VarInt::new(0x05).unwrap(),
                }
            ))
        );
    }

    #[test]
    fn test_decode_settings_frame_empty_payload() {
        let buf = [0x04, 0x00];
        assert_eq!(
            decode_frame(&buf),
            Ok((Frame::Settings(SettingsPayload::new()), 2))
        );
    }

    #[test]
    fn test_decode_settings_frame_truncated_pair_is_invalid_length() {
        // Identifier 0x06 present, its value missing from the payload.
        let buf = [0x04, 0x01, 0x06];
        assert_eq!(decode_frame(&buf), Err(FrameDecodeError::InvalidLength));
    }

    #[test]
    fn test_decode_server_push_frames_not_supported() {
        let buf = [0x03, 0x01, 0x00];
        let result = decode_frame(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::ServerPushNotSupported(
                VarInt::from_static(0x03)
            ))
        );

        let buf = [0x05, 0x02, 0x00, 0x00];
        let result = decode_frame(&buf);
        assert_eq!(
            result,
            Err(FrameDecodeError::ServerPushNotSupported(
                VarInt::from_static(0x05)
            ))
        );
    }

    #[test]
    fn test_decode_max_push_id_frame() {
        let buf = [0x0d, 0x01, 0x05];
        let result = decode_frame(&buf).unwrap();
        assert_eq!(result, (Frame::MaxPushId(VarInt::from_static(5)), 3));
    }

    #[test]
    fn test_decode_max_push_id_frame_length_must_match() {
        let buf = [0x0d, 0x00];
        assert_eq!(decode_frame(&buf), Err(FrameDecodeError::InvalidLength));

        let buf = [0x0d, 0x02, 0x05, 0x00];
        assert_eq!(decode_frame(&buf), Err(FrameDecodeError::InvalidLength));
    }

    #[test]
    fn test_decode_frame_leaves_trailing_bytes_unconsumed() {
        let buf = [0x0d, 0x01, 0x05, 0xff];
        assert_eq!(
            decode_frame(&buf),
            Ok((Frame::MaxPushId(VarInt::from_static(5)), 3))
        );
    }

    #[test]
    fn test_decode_goaway_frame_empty_payload_is_invalid_length() {
        let buf = [0x07, 0x00];
        assert_eq!(decode_frame(&buf), Err(FrameDecodeError::InvalidLength));
    }

    #[test]
    fn test_decode_goaway_frame_trailing_bytes_is_invalid_length() {
        let buf = [0x07, 0x02, 0x05, 0x99];
        assert_eq!(decode_frame(&buf), Err(FrameDecodeError::InvalidLength));
    }
}