use cookie_factory::{
GenError,
bytes::{be_u8, be_u16, be_u24, be_u32},
combinator::slice,
r#gen,
sequence::tuple,
};
use crate::protocol::mux::{
h2::H2Settings,
parser::{self, FrameHeader, FrameType, H2Error},
};
/// HTTP/2 client connection preface (RFC 9113 §3.4).
pub const H2_PRI: &str = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

/// Complete SETTINGS frame acknowledging the peer's settings (RFC 9113 §6.5).
pub const SETTINGS_ACKNOWLEDGEMENT: [u8; 9] = [
    0x00, 0x00, 0x00, // payload length: 0 (an ACK carries no settings)
    0x04, // type: SETTINGS
    0x01, // flags: ACK
    0x00, 0x00, 0x00, 0x00, // stream identifier: 0 (connection-level)
];

/// Header of a PING frame acknowledging a received PING (RFC 9113 §6.7).
/// The 8-octet opaque payload being echoed must follow it.
pub const PING_ACKNOWLEDGEMENT_HEADER: [u8; 9] = [
    0x00, 0x00, 0x08, // payload length: 8
    0x06, // type: PING
    0x01, // flags: ACK
    0x00, 0x00, 0x00, 0x00, // stream identifier: 0 (connection-level)
];
pub fn gen_frame_header<'a>(
buf: &'a mut [u8],
frame: &FrameHeader,
) -> Result<(&'a mut [u8], usize), GenError> {
let serializer = tuple((
be_u24(frame.payload_len),
be_u8(serialize_frame_type(&frame.frame_type)),
be_u8(frame.flags),
be_u32(frame.stream_id & parser::STREAM_ID_MASK),
));
r#gen(serializer, buf).map(|(buf, size)| (buf, size as usize))
}
/// Returns the wire-level type octet for `f`.
///
/// Codes follow RFC 9113 §6 (0x0-0x9) and RFC 9218 (PRIORITY_UPDATE, 0x10).
/// `FrameType::Unknown` carries its raw code, which is emitted unchanged.
pub fn serialize_frame_type(f: &FrameType) -> u8 {
    match f {
        FrameType::Data => 0x00,
        FrameType::Headers => 0x01,
        FrameType::Priority => 0x02,
        FrameType::RstStream => 0x03,
        FrameType::Settings => 0x04,
        FrameType::PushPromise => 0x05,
        FrameType::Ping => 0x06,
        FrameType::GoAway => 0x07,
        FrameType::WindowUpdate => 0x08,
        FrameType::Continuation => 0x09,
        FrameType::PriorityUpdate => 0x10,
        FrameType::Unknown(code) => *code,
    }
}
pub fn gen_ping_acknowledgement<'a>(
buf: &'a mut [u8],
payload: &[u8],
) -> Result<(&'a mut [u8], usize), GenError> {
r#gen(
tuple((slice(PING_ACKNOWLEDGEMENT_HEADER), slice(payload))),
buf,
)
.map(|(buf, size)| (buf, size as usize))
}
pub fn gen_settings<'a>(
buf: &'a mut [u8],
settings: &H2Settings,
) -> Result<(&'a mut [u8], usize), GenError> {
gen_frame_header(
buf,
&FrameHeader {
payload_len: parser::SETTINGS_ENTRY_SIZE * parser::SETTINGS_COUNT,
frame_type: FrameType::Settings,
flags: 0,
stream_id: 0,
},
)
.and_then(|(buf, old_size)| {
r#gen(
tuple((
be_u16(parser::SETTINGS_HEADER_TABLE_SIZE),
be_u32(settings.settings_header_table_size),
be_u16(parser::SETTINGS_ENABLE_PUSH),
be_u32(settings.settings_enable_push as u32),
be_u16(parser::SETTINGS_MAX_CONCURRENT_STREAMS),
be_u32(settings.settings_max_concurrent_streams),
be_u16(parser::SETTINGS_INITIAL_WINDOW_SIZE),
be_u32(settings.settings_initial_window_size),
be_u16(parser::SETTINGS_MAX_FRAME_SIZE),
be_u32(settings.settings_max_frame_size),
be_u16(parser::SETTINGS_MAX_HEADER_LIST_SIZE),
be_u32(settings.settings_max_header_list_size),
be_u16(parser::SETTINGS_ENABLE_CONNECT_PROTOCOL),
be_u32(settings.settings_enable_connect_protocol as u32),
be_u16(parser::SETTINGS_NO_RFC7540_PRIORITIES),
be_u32(settings.settings_no_rfc7540_priorities as u32),
)),
buf,
)
.map(|(buf, size)| (buf, (old_size + size as usize)))
})
}
/// Writes `header` followed by a payload produced by the `payload` closure.
///
/// `payload` receives the part of `buf` remaining after the 9-byte header and
/// must return its own unwritten tail plus the number of bytes it wrote. The
/// returned size is the header size plus that payload size. Note that this
/// helper does not check `header.payload_len` against the bytes `payload`
/// actually writes.
fn gen_control_frame<F>(
    buf: &mut [u8],
    header: FrameHeader,
    payload: F,
) -> Result<(&mut [u8], usize), GenError>
where
    F: FnOnce(&mut [u8]) -> Result<(&mut [u8], u64), GenError>,
{
    let (after_header, header_size) = gen_frame_header(buf, &header)?;
    let (after_payload, payload_size) = payload(after_header)?;
    Ok((after_payload, header_size + payload_size as usize))
}
/// Serializes an RST_STREAM frame (RFC 9113 §6.4) terminating `stream_id`
/// with `error_code`, written as a 32-bit big-endian integer.
///
/// Returns the unwritten tail of `buf` and the total number of bytes written.
pub fn gen_rst_stream(
    buf: &mut [u8],
    stream_id: u32,
    error_code: H2Error,
) -> Result<(&mut [u8], usize), GenError> {
    let code = error_code as u32;
    let header = FrameHeader {
        payload_len: parser::RST_STREAM_PAYLOAD_SIZE,
        frame_type: FrameType::RstStream,
        flags: 0,
        stream_id,
    };
    gen_control_frame(buf, header, move |rest| r#gen(be_u32(code), rest))
}
/// Serializes a WINDOW_UPDATE frame (RFC 9113 §6.9) for `stream_id`; stream 0
/// addresses the connection-level window.
///
/// `increment` is masked with `parser::STREAM_ID_MASK`, which clears the
/// reserved high bit. Increments of 2^31 or more are therefore not written
/// unchanged, so callers must keep them in range.
///
/// Returns the unwritten tail of `buf` and the total number of bytes written.
pub fn gen_window_update(
    buf: &mut [u8],
    stream_id: u32,
    increment: u32,
) -> Result<(&mut [u8], usize), GenError> {
    let increment = increment & parser::STREAM_ID_MASK;
    let header = FrameHeader {
        payload_len: parser::WINDOW_UPDATE_PAYLOAD_SIZE,
        frame_type: FrameType::WindowUpdate,
        flags: 0,
        stream_id,
    };
    gen_control_frame(buf, header, move |rest| r#gen(be_u32(increment), rest))
}
/// Serializes a GOAWAY frame (RFC 9113 §6.8) on stream 0.
///
/// The payload is the last stream identifier (reserved high bit cleared)
/// followed by `error_code` as a 32-bit big-endian integer. No additional
/// debug data is written.
///
/// Returns the unwritten tail of `buf` and the total number of bytes written.
pub fn gen_goaway(
    buf: &mut [u8],
    last_stream_id: u32,
    error_code: H2Error,
) -> Result<(&mut [u8], usize), GenError> {
    let last_stream_id = last_stream_id & parser::STREAM_ID_MASK;
    let code = error_code as u32;
    let header = FrameHeader {
        payload_len: parser::GOAWAY_PAYLOAD_SIZE,
        frame_type: FrameType::GoAway,
        flags: 0,
        stream_id: 0,
    };
    gen_control_frame(buf, header, move |rest| {
        r#gen(tuple((be_u32(last_stream_id), be_u32(code))), rest)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::mux::parser;

    /// Serializes `header` into a fresh 9-byte buffer, returning the buffer and
    /// the reported number of bytes written.
    fn serialize_header(header: &parser::FrameHeader) -> ([u8; 9], usize) {
        let mut buf = [0u8; 9];
        let (_, sz) = gen_frame_header(&mut buf[..], header).expect("serialization should succeed");
        (buf, sz)
    }

    /// Parses a frame header from `buf` and asserts that every byte was consumed.
    /// The second argument to `parser::frame_header` is presumably the maximum
    /// frame size; 2^24 - 1 admits any length the 24-bit field can encode.
    fn parse_header(buf: &[u8]) -> parser::FrameHeader {
        let (remaining, header) =
            parser::frame_header(buf, 16_777_215).expect("parsing should succeed");
        assert!(remaining.is_empty(), "all bytes should be consumed");
        header
    }

    /// Serializes `original`, checks that exactly 9 bytes were written, parses
    /// the bytes back and asserts the parsed header equals `original`.
    fn assert_roundtrip(original: parser::FrameHeader) {
        let (buf, sz) = serialize_header(&original);
        assert_eq!(sz, 9, "frame header is always 9 bytes");
        let parsed = parse_header(&buf);
        assert_eq!(parsed, original);
    }

    #[test]
    fn roundtrip_data_frame_header() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 100,
            frame_type: parser::FrameType::Data,
            flags: 0x1, // END_STREAM
            stream_id: 1,
        });
    }

    #[test]
    fn roundtrip_headers_frame_header() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 256,
            frame_type: parser::FrameType::Headers,
            flags: 0x25, // END_STREAM | END_HEADERS | PRIORITY
            stream_id: 3,
        });
    }

    #[test]
    fn roundtrip_settings_frame_header() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 36,
            frame_type: parser::FrameType::Settings,
            flags: 0x0,
            stream_id: 0,
        });
    }

    #[test]
    fn roundtrip_settings_ack_header() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 0,
            frame_type: parser::FrameType::Settings,
            flags: 0x1, // ACK
            stream_id: 0,
        });
    }

    #[test]
    fn roundtrip_rst_stream_header() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 4,
            frame_type: parser::FrameType::RstStream,
            flags: 0x0,
            stream_id: 7,
        });
    }

    #[test]
    fn roundtrip_window_update_header() {
        // Connection-level window update.
        assert_roundtrip(parser::FrameHeader {
            payload_len: 4,
            frame_type: parser::FrameType::WindowUpdate,
            flags: 0x0,
            stream_id: 0,
        });
    }

    #[test]
    fn roundtrip_window_update_stream_level() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 4,
            frame_type: parser::FrameType::WindowUpdate,
            flags: 0x0,
            stream_id: 5,
        });
    }

    #[test]
    fn roundtrip_ping_header() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 8,
            frame_type: parser::FrameType::Ping,
            flags: 0x1, // ACK
            stream_id: 0,
        });
    }

    #[test]
    fn roundtrip_goaway_header() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 8,
            frame_type: parser::FrameType::GoAway,
            flags: 0x0,
            stream_id: 0,
        });
    }

    #[test]
    fn roundtrip_continuation_header() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 128,
            frame_type: parser::FrameType::Continuation,
            flags: 0x4, // END_HEADERS
            stream_id: 9,
        });
    }

    #[test]
    fn roundtrip_all_frame_types() {
        let frame_types = [
            (parser::FrameType::Data, 0u8),
            (parser::FrameType::Headers, 1),
            (parser::FrameType::Priority, 2),
            (parser::FrameType::RstStream, 3),
            (parser::FrameType::Settings, 4),
            (parser::FrameType::PushPromise, 5),
            (parser::FrameType::Ping, 6),
            (parser::FrameType::GoAway, 7),
            (parser::FrameType::WindowUpdate, 8),
            (parser::FrameType::Continuation, 9),
        ];
        for (ft, expected_byte) in &frame_types {
            assert_eq!(
                serialize_frame_type(ft),
                *expected_byte,
                "serialize_frame_type mismatch for {ft:?}"
            );
            // SETTINGS, PING and GOAWAY are connection-level frames, sent on stream 0.
            let stream_id = match ft {
                parser::FrameType::Settings
                | parser::FrameType::Ping
                | parser::FrameType::GoAway => 0,
                _ => 1,
            };
            let header = parser::FrameHeader {
                payload_len: 0,
                frame_type: ft.to_owned(),
                flags: 0,
                stream_id,
            };
            let (buf, sz) = serialize_header(&header);
            assert_eq!(sz, 9);
            let parsed = parse_header(&buf);
            assert_eq!(parsed.frame_type, *ft, "round-trip failed for {ft:?}");
        }
    }

    #[test]
    fn serialize_extension_and_unknown_frame_types() {
        // PRIORITY_UPDATE (RFC 9218) is not covered by the table above.
        assert_eq!(serialize_frame_type(&parser::FrameType::PriorityUpdate), 0x10);
        // Unknown types are emitted with their raw code.
        assert_eq!(serialize_frame_type(&parser::FrameType::Unknown(0xFA)), 0xFA);
    }

    #[test]
    fn roundtrip_large_payload_len() {
        // Largest value the 24-bit length field can hold.
        assert_roundtrip(parser::FrameHeader {
            payload_len: 16_777_215,
            frame_type: parser::FrameType::Data,
            flags: 0x0,
            stream_id: 1,
        });
    }

    #[test]
    fn roundtrip_zero_payload_len() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 0,
            frame_type: parser::FrameType::Data,
            flags: 0x0,
            stream_id: 1,
        });
    }

    #[test]
    fn roundtrip_max_stream_id() {
        // Largest 31-bit stream identifier.
        assert_roundtrip(parser::FrameHeader {
            payload_len: 4,
            frame_type: parser::FrameType::WindowUpdate,
            flags: 0x0,
            stream_id: 0x7FFF_FFFF,
        });
    }

    #[test]
    fn roundtrip_all_flags_set() {
        assert_roundtrip(parser::FrameHeader {
            payload_len: 50,
            frame_type: parser::FrameType::Headers,
            flags: 0xFF,
            stream_id: 1,
        });
    }

    #[test]
    fn buffer_too_small_for_frame_header() {
        let header = parser::FrameHeader {
            payload_len: 10,
            frame_type: parser::FrameType::Data,
            flags: 0,
            stream_id: 1,
        };
        // One byte short of the 9-byte header.
        let mut buf = [0u8; 8];
        let result = gen_frame_header(&mut buf[..], &header);
        assert!(result.is_err(), "should fail with buffer too small");
    }

    #[test]
    fn serialized_bytes_match_expected_layout() {
        let header = parser::FrameHeader {
            payload_len: 0x000102,
            frame_type: parser::FrameType::Settings,
            flags: 0xAB,
            stream_id: 0x0304_0506,
        };
        let (buf, _) = serialize_header(&header);
        // 24-bit big-endian payload length.
        assert_eq!(buf[0], 0x00);
        assert_eq!(buf[1], 0x01);
        assert_eq!(buf[2], 0x02);
        // Frame type (SETTINGS) and flags.
        assert_eq!(buf[3], 0x04);
        assert_eq!(buf[4], 0xAB);
        // 32-bit big-endian stream identifier.
        assert_eq!(buf[5], 0x03);
        assert_eq!(buf[6], 0x04);
        assert_eq!(buf[7], 0x05);
        assert_eq!(buf[8], 0x06);
    }

    #[test]
    fn gen_frame_header_masks_reserved_bit() {
        let original = parser::FrameHeader {
            payload_len: 0,
            frame_type: parser::FrameType::Ping,
            flags: 0,
            stream_id: 0xFFFF_FFFF,
        };
        let (buf, sz) = serialize_header(&original);
        assert_eq!(sz, 9);
        assert_eq!(buf[5] & 0x80, 0);
        let stream_id = u32::from_be_bytes([buf[5], buf[6], buf[7], buf[8]]);
        assert_eq!(stream_id, 0x7FFF_FFFF);
    }

    #[test]
    fn gen_goaway_masks_reserved_bit_in_last_stream_id() {
        // 9-byte header + 4-byte last stream id + 4-byte error code.
        let mut buf = [0u8; 17];
        let (_, sz) = gen_goaway(&mut buf[..], 0xFFFF_FFFF, H2Error::ProtocolError)
            .expect("serialization should succeed");
        assert_eq!(sz, 17);
        assert_eq!(buf[9] & 0x80, 0);
        let last_stream_id = u32::from_be_bytes([buf[9], buf[10], buf[11], buf[12]]);
        assert_eq!(last_stream_id, 0x7FFF_FFFF);
    }

    #[test]
    fn gen_ping_acknowledgement_echoes_payload() {
        let payload = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let mut buf = [0u8; 17];
        let (_, sz) = gen_ping_acknowledgement(&mut buf[..], &payload)
            .expect("serialization should succeed");
        assert_eq!(sz, 17);
        assert_eq!(&buf[..9], &PING_ACKNOWLEDGEMENT_HEADER[..]);
        assert_eq!(&buf[9..], &payload[..]);
    }

    #[test]
    fn gen_rst_stream_writes_header_and_error_code() {
        // 9-byte header + 4-byte error code.
        let mut buf = [0u8; 13];
        let (_, sz) = gen_rst_stream(&mut buf[..], 7, H2Error::ProtocolError)
            .expect("serialization should succeed");
        assert_eq!(sz, 13);
        let header = parse_header(&buf[..9]);
        assert_eq!(header.frame_type, parser::FrameType::RstStream);
        assert_eq!(header.stream_id, 7);
        // The declared payload length must match what was actually written.
        assert_eq!(header.payload_len as usize, sz - 9);
        let code = u32::from_be_bytes([buf[9], buf[10], buf[11], buf[12]]);
        assert_eq!(code, H2Error::ProtocolError as u32);
    }

    #[test]
    fn gen_window_update_masks_reserved_bit_in_increment() {
        // 9-byte header + 4-byte window size increment.
        let mut buf = [0u8; 13];
        let (_, sz) = gen_window_update(&mut buf[..], 5, 0xFFFF_FFFF)
            .expect("serialization should succeed");
        assert_eq!(sz, 13);
        let header = parse_header(&buf[..9]);
        assert_eq!(header.frame_type, parser::FrameType::WindowUpdate);
        assert_eq!(header.stream_id, 5);
        assert_eq!(header.payload_len as usize, sz - 9);
        assert_eq!(buf[9] & 0x80, 0);
        let increment = u32::from_be_bytes([buf[9], buf[10], buf[11], buf[12]]);
        assert_eq!(increment, 0x7FFF_FFFF);
    }
}