use std::collections::HashMap;
use crate::{
frame::*,
state::{
H2ConnectionState,
ParseError,
ParseErrorKind,
ParsedH2Message,
StreamId,
StreamPhase,
StreamState,
TimestampNs,
},
};
/// Parses every complete HTTP/2 frame in `buffer` against the connection `state`
/// and returns the messages whose streams finished during this pass, keyed by
/// stream id.
///
/// The client connection preface is skipped the first time it is seen at the
/// start of `buffer`. Parsing stops at the first incomplete trailing frame; this
/// function does not retain those bytes. After a successful pass, stale streams
/// are evicted using `state.current_timestamp_ns`.
///
/// # Errors
///
/// Returns on the first problem. In that case, messages completed earlier in the
/// same call are discarded and eviction is skipped:
/// - any error from `parse_frame_header` or from a frame handler;
/// - `Http2FrameSizeError` when a frame's length exceeds `settings.max_frame_size`;
/// - `Http2InvalidFrame` if the total frame size cannot be computed;
/// - `Http2ContinuationExpected` when a header block is open and the next frame is
///   not a CONTINUATION on that same stream.
pub(crate) fn parse_frames_stateful(
    buffer: &[u8],
    state: &mut H2ConnectionState,
) -> Result<HashMap<StreamId, ParsedH2Message>, ParseError> {
    let now = state.current_timestamp_ns;
    let mut finished = HashMap::new();

    let mut offset = if !state.preface_received && buffer.starts_with(CONNECTION_PREFACE) {
        state.preface_received = true;
        CONNECTION_PREFACE.len()
    } else {
        0
    };

    loop {
        let remaining = &buffer[offset..];
        if remaining.len() < FRAME_HEADER_SIZE {
            break;
        }

        let header = parse_frame_header(remaining)?;
        if header.length > state.settings.max_frame_size {
            return Err(ParseError::new(ParseErrorKind::Http2FrameSizeError));
        }
        let Some(frame_len) = FRAME_HEADER_SIZE.checked_add(header.length as usize) else {
            return Err(ParseError::new(ParseErrorKind::Http2InvalidFrame));
        };
        // A partial trailing frame: nothing more can be parsed from this buffer.
        let Some(frame) = remaining.get(..frame_len) else {
            break;
        };

        // While a header block is open, only CONTINUATION frames on the same
        // stream may follow (RFC 7540 §6.10).
        if let Some(open_stream) = state.expecting_continuation {
            let continues_block = header.frame_type == FRAME_TYPE_CONTINUATION
                && header.stream_id == open_stream;
            if !continues_block {
                return Err(ParseError::new(ParseErrorKind::Http2ContinuationExpected));
            }
        }

        let payload = &frame[FRAME_HEADER_SIZE..];
        match header.frame_type {
            FRAME_TYPE_DATA => handle_data_frame(state, &header, payload, now)?,
            FRAME_TYPE_HEADERS => handle_headers_frame(state, &header, payload, now)?,
            FRAME_TYPE_CONTINUATION => handle_continuation_frame(state, &header, payload)?,
            FRAME_TYPE_SETTINGS => handle_settings_frame(state, &header, payload)?,
            FRAME_TYPE_RST_STREAM => handle_rst_stream(state, &header, payload),
            FRAME_TYPE_GOAWAY => handle_goaway(payload),
            // Recognised but carry nothing this parser records.
            FRAME_TYPE_PRIORITY
            | FRAME_TYPE_PUSH_PROMISE
            | FRAME_TYPE_PING
            | FRAME_TYPE_WINDOW_UPDATE => {},
            // Unknown frame types are ignored.
            _ => {},
        }
        offset += frame_len;

        // Connection-level frames (stream 0) never complete a message.
        if header.stream_id == StreamId(0) {
            continue;
        }
        if let Some((id, message)) = check_stream_completion(state, header.stream_id) {
            finished.insert(id, message);
        }
    }

    state.evict_stale_streams(now);
    Ok(finished)
}
/// Parses as many complete frames as are available in `state.buffer`. Messages
/// whose streams complete are pushed onto `state.completed`.
///
/// This parser is lenient:
/// - an unparseable frame header, or a frame that is not yet fully buffered, stops
///   parsing without consuming those bytes, so they stay at the front of
///   `state.buffer`;
/// - a frame that interrupts an open header block abandons that block and drops
///   its stream;
/// - handler errors other than HPACK errors are logged, and the frame is skipped.
///
/// All bytes of fully processed frames (and the connection preface, if present)
/// are drained from `state.buffer`. Stale streams are then evicted.
///
/// # Errors
///
/// Returns the first `Http2HpackError`. That error is treated as fatal because the
/// decoder's dynamic table can no longer be trusted. Bytes consumed before and
/// including the failing frame are still drained.
pub(crate) fn parse_buffer_incremental(state: &mut H2ConnectionState) -> Result<(), ParseError> {
    let timestamp_ns = state.current_timestamp_ns;
    // Move the buffer out of `state` so frame payloads can be borrowed directly
    // while the handlers mutate `state`. This avoids copying every payload.
    // The buffer is put back after the loop (the loop only ever `break`s).
    let mut buffer = std::mem::take(&mut state.buffer);
    let mut pos = 0;
    let mut fatal_error: Option<ParseError> = None;
    if !state.preface_received && buffer.starts_with(CONNECTION_PREFACE) {
        pos += CONNECTION_PREFACE.len();
        state.preface_received = true;
    }
    while pos + FRAME_HEADER_SIZE <= buffer.len() {
        let Ok(header) = parse_frame_header(&buffer[pos..]) else {
            break;
        };
        let Some(frame_total_size) = FRAME_HEADER_SIZE.checked_add(header.length as usize)
        else {
            break;
        };
        if pos + frame_total_size > buffer.len() {
            // Incomplete frame: wait for more data.
            break;
        }
        if let Some(expected_stream) = state.expecting_continuation
            && (header.frame_type != FRAME_TYPE_CONTINUATION || header.stream_id != expected_stream)
        {
            crate::trace_warn!(
                "expected CONTINUATION for stream {expected_stream}, got frame type {} on stream \
                 {}; abandoning incomplete header block",
                header.frame_type,
                header.stream_id
            );
            state.expecting_continuation = None;
            state.active_streams.remove(&expected_stream);
        }
        let frame_payload = &buffer[pos + FRAME_HEADER_SIZE..pos + frame_total_size];
        let result = match header.frame_type {
            FRAME_TYPE_DATA => handle_data_frame(state, &header, frame_payload, timestamp_ns),
            FRAME_TYPE_HEADERS => {
                handle_headers_frame(state, &header, frame_payload, timestamp_ns)
            },
            FRAME_TYPE_CONTINUATION => handle_continuation_frame(state, &header, frame_payload),
            FRAME_TYPE_SETTINGS => handle_settings_frame(state, &header, frame_payload),
            FRAME_TYPE_RST_STREAM => {
                handle_rst_stream(state, &header, frame_payload);
                Ok(())
            },
            FRAME_TYPE_GOAWAY => {
                handle_goaway(frame_payload);
                Ok(())
            },
            FRAME_TYPE_PRIORITY
            | FRAME_TYPE_PUSH_PROMISE
            | FRAME_TYPE_PING
            | FRAME_TYPE_WINDOW_UPDATE => Ok(()),
            _ => Ok(()),
        };
        pos += frame_total_size;
        let stream_id = header.stream_id;
        if let Err(e) = result {
            // An HPACK failure desynchronises the shared dynamic table, so every
            // later header block would decode wrongly. Stop here.
            if matches!(e.kind, ParseErrorKind::Http2HpackError(_)) {
                crate::trace_warn!("fatal HPACK error on stream {stream_id}: {e}");
                fatal_error = Some(e);
                break;
            }
            crate::trace_warn!("non-fatal frame error on stream {stream_id}: {e}");
            continue;
        }
        if stream_id != StreamId(0)
            && let Some(pair) = check_stream_completion(state, stream_id)
        {
            state.completed.push_back(pair);
        }
    }
    if pos > 0 {
        buffer.drain(..pos);
    }
    state.buffer = buffer;
    state.evict_stale_streams(timestamp_ns);
    match fatal_error {
        Some(e) => Err(e),
        None => Ok(()),
    }
}
/// Handles a HEADERS frame (RFC 7540 §6.2) for `header.stream_id`.
///
/// Steps:
/// 1. A new stream is subject to `limits.max_concurrent_streams`. Ordering and
///    parity anomalies in its id are logged.
/// 2. The padding and priority fields are validated and stripped.
/// 3. The stream is opened (or reused), and `highest_stream_id` is updated for new
///    streams.
/// 4. If END_HEADERS is set, the header block is HPACK-decoded. Otherwise it is
///    buffered, and a CONTINUATION is marked as expected.
/// 5. The stream phase is updated from END_HEADERS / END_STREAM.
///
/// # Errors
///
/// - `Http2MaxConcurrentStreams` if the stream is new and the limit is reached;
/// - `Http2PaddingError` for an empty padded payload or padding that is too long;
/// - `Http2PriorityError` if the PRIORITY flag is set but fewer than 5 bytes remain;
/// - any error from `decode_headers_into_stream`.
fn handle_headers_frame(
    state: &mut H2ConnectionState,
    header: &FrameHeader,
    payload: &[u8],
    timestamp_ns: TimestampNs,
) -> Result<(), ParseError> {
    let stream_id = header.stream_id;
    let is_new_stream = !state.active_streams.contains_key(&stream_id);

    if is_new_stream {
        if state.active_streams.len() >= state.limits.max_concurrent_streams {
            crate::trace_warn!("max concurrent streams reached, rejecting stream {stream_id}");
            return Err(ParseError::with_stream(
                ParseErrorKind::Http2MaxConcurrentStreams,
                stream_id,
            ));
        }
        if stream_id.0 != 0 && stream_id <= state.highest_stream_id {
            crate::trace_warn!(
                "stream {stream_id} not greater than highest seen ({}); RFC 7540 §5.1.1 violation",
                state.highest_stream_id
            );
        }
        if stream_id.0.is_multiple_of(2) && stream_id.0 != 0 {
            crate::trace_warn!(
                "even stream ID {stream_id} (server-initiated); unexpected for client traffic"
            );
        }
    }

    // Validate and strip the framing fields *before* touching the stream table.
    // Otherwise a malformed frame on a new stream would leave an empty stream in
    // `active_streams` (counting against the concurrency limit until evicted) and
    // would bump `highest_stream_id`.
    let header_block = if header.flags & FLAG_PADDED != 0 {
        let Some(&pad_byte) = payload.first() else {
            return Err(ParseError::with_stream(
                ParseErrorKind::Http2PaddingError,
                stream_id,
            ));
        };
        let pad_len = usize::from(pad_byte);
        if pad_len >= payload.len() {
            return Err(ParseError::with_stream(
                ParseErrorKind::Http2PaddingError,
                stream_id,
            ));
        }
        &payload[1..payload.len() - pad_len]
    } else {
        payload
    };
    let header_block = if header.flags & FLAG_PRIORITY != 0 {
        // 4-byte exclusive-bit + stream dependency, then a 1-byte weight.
        if header_block.len() < 5 {
            return Err(ParseError::with_stream(
                ParseErrorKind::Http2PriorityError,
                stream_id,
            ));
        }
        let dependency = u32::from_be_bytes([
            header_block[0] & 0x7F,
            header_block[1],
            header_block[2],
            header_block[3],
        ]);
        if dependency == stream_id.0 {
            crate::trace_warn!("stream {stream_id} depends on itself (RFC 7540 §5.3.1 violation)");
        }
        &header_block[5..]
    } else {
        header_block
    };

    let stream = state.active_streams.entry(stream_id).or_insert_with(|| {
        if stream_id > state.highest_stream_id {
            state.highest_stream_id = stream_id;
        }
        StreamState::new(stream_id, timestamp_ns)
    });
    stream.header_size += FRAME_HEADER_SIZE + payload.len();

    let has_end_headers = header.flags & FLAG_END_HEADERS != 0;
    let has_end_stream = header.flags & FLAG_END_STREAM != 0;
    if has_end_headers {
        // END_HEADERS closes the header block at the framing level, whether or not
        // it decodes.
        state.expecting_continuation = None;
        if stream.continuation_buffer.is_empty() {
            // Common case: the whole block is in this frame, so decode it in place.
            decode_headers_into_stream(&mut state.decoder, stream, header_block, &state.limits)?;
        } else {
            stream.continuation_buffer.extend_from_slice(header_block);
            let full_block = std::mem::take(&mut stream.continuation_buffer);
            decode_headers_into_stream(&mut state.decoder, stream, &full_block, &state.limits)?;
        }
    } else {
        stream.continuation_buffer.extend_from_slice(header_block);
        state.expecting_continuation = Some(stream_id);
    }
    if has_end_stream {
        stream.end_stream_timestamp_ns = timestamp_ns;
    }
    stream.phase = match (has_end_headers, has_end_stream) {
        (true, true) => StreamPhase::Complete,
        (true, false) => StreamPhase::ReceivingBody,
        // Remember END_STREAM so the final CONTINUATION can complete the stream.
        (false, es) => StreamPhase::ReceivingHeaders {
            end_stream_seen: es,
        },
    };
    Ok(())
}
/// Handles a CONTINUATION frame (RFC 7540 §6.10). The fragment is appended to the
/// stream's pending header block.
///
/// On END_HEADERS, the whole block is HPACK-decoded. The stream then moves to
/// `Complete` if END_STREAM was seen on the opening HEADERS frame, and to
/// `ReceivingBody` otherwise.
///
/// # Errors
///
/// - `Http2HeadersIncomplete` if the stream is not tracked;
/// - any error from `decode_headers_into_stream`.
fn handle_continuation_frame(
    state: &mut H2ConnectionState,
    header: &FrameHeader,
    payload: &[u8],
) -> Result<(), ParseError> {
    let end_headers = header.flags & FLAG_END_HEADERS != 0;
    if end_headers {
        // END_HEADERS terminates the header block at the framing level, even if the
        // stream is no longer tracked (e.g. evicted) or the block fails to decode.
        // Clearing this first prevents the parser from waiting forever for a
        // CONTINUATION that will never arrive, which would make
        // `parse_frames_stateful` reject every later frame.
        state.expecting_continuation = None;
    }

    let stream = state
        .active_streams
        .get_mut(&header.stream_id)
        .ok_or_else(|| {
            ParseError::with_stream(ParseErrorKind::Http2HeadersIncomplete, header.stream_id)
        })?;
    stream.continuation_buffer.extend_from_slice(payload);
    stream.header_size += FRAME_HEADER_SIZE + payload.len();
    if !end_headers {
        return Ok(());
    }

    let buf = std::mem::take(&mut stream.continuation_buffer);
    decode_headers_into_stream(&mut state.decoder, stream, &buf, &state.limits)?;
    stream.phase = match stream.phase {
        StreamPhase::ReceivingHeaders {
            end_stream_seen: true,
        } => StreamPhase::Complete,
        _ => StreamPhase::ReceivingBody,
    };
    Ok(())
}
/// Handles a DATA frame (RFC 7540 §6.1). Padding is stripped, and the data is
/// appended to the stream body. On END_STREAM, the end timestamp is recorded and
/// the stream is marked `Complete`.
///
/// If appending would exceed `limits.max_body_size`, the stream is dropped from
/// `active_streams` with a warning, and `Ok(())` is returned.
///
/// # Errors
///
/// - `Http2PaddingError` for an empty padded payload or padding that is too long;
/// - `Http2StreamNotFound` if the stream is not tracked.
fn handle_data_frame(
    state: &mut H2ConnectionState,
    header: &FrameHeader,
    payload: &[u8],
    timestamp_ns: TimestampNs,
) -> Result<(), ParseError> {
    let sid = header.stream_id;
    let padding_error = || ParseError::with_stream(ParseErrorKind::Http2PaddingError, sid);

    let body_bytes = match (header.flags & FLAG_PADDED != 0, payload.split_first()) {
        (false, _) => payload,
        (true, None) => return Err(padding_error()),
        (true, Some((&pad, rest))) => {
            let pad = usize::from(pad);
            if pad >= payload.len() {
                return Err(padding_error());
            }
            &rest[..rest.len() - pad]
        },
    };

    let Some(stream) = state.active_streams.get_mut(&sid) else {
        return Err(ParseError::with_stream(
            ParseErrorKind::Http2StreamNotFound,
            sid,
        ));
    };
    if matches!(stream.phase, StreamPhase::ReceivingHeaders { .. }) {
        crate::trace_warn!(
            "DATA on stream {} before headers complete (RFC 7540 §8.1)",
            sid
        );
    }
    if stream.body.len() + body_bytes.len() > state.limits.max_body_size {
        crate::trace_warn!(
            "body size limit exceeded on stream {}, dropping stream",
            sid
        );
        state.active_streams.remove(&sid);
        return Ok(());
    }

    stream.body.extend_from_slice(body_bytes);
    if header.flags & FLAG_END_STREAM != 0 {
        stream.end_stream_timestamp_ns = timestamp_ns;
        stream.phase = StreamPhase::Complete;
    }
    Ok(())
}
/// Applies a SETTINGS frame (RFC 7540 §6.5) to `state.settings`.
///
/// The payload is a sequence of 6-byte entries: a 16-bit identifier followed by a
/// 32-bit value. Handling per identifier:
/// - HEADER_TABLE_SIZE (0x01) is capped at `limits.max_table_size`, and the HPACK
///   decoder is resized to the capped value;
/// - MAX_FRAME_SIZE (0x05) is applied only when it lies in
///   `16_384..=MAX_FRAME_PAYLOAD_LENGTH`;
/// - other known identifiers are stored as-is, and unknown identifiers are ignored.
///
/// # Errors
///
/// `Http2SettingsLengthError` if the payload length is not a multiple of 6.
fn handle_settings_frame(
    state: &mut H2ConnectionState,
    _header: &FrameHeader,
    payload: &[u8],
) -> Result<(), ParseError> {
    if !payload.len().is_multiple_of(6) {
        return Err(ParseError::new(ParseErrorKind::Http2SettingsLengthError));
    }
    // Saturate rather than truncate: `as u32` would wrap a limit above u32::MAX
    // (e.g. 2^32 -> 0) and shrink the decoder table to nothing.
    let table_size_cap = u32::try_from(state.limits.max_table_size).unwrap_or(u32::MAX);
    for entry in payload.chunks_exact(6) {
        let setting_id = u16::from_be_bytes([entry[0], entry[1]]);
        let value = u32::from_be_bytes([entry[2], entry[3], entry[4], entry[5]]);
        match setting_id {
            0x01 => {
                // NOTE(review): HEADER_TABLE_SIZE advertises the *sender's* decoder
                // limit, which constrains the peer's encoder. Whether it should
                // resize this decoder depends on which direction `state` tracks.
                // Confirm.
                let capped = value.min(table_size_cap);
                state.settings.header_table_size = capped;
                state.decoder.set_max_table_size(capped as usize);
            },
            0x02 => state.settings.enable_push = value != 0,
            0x03 => state.settings.max_concurrent_streams = value,
            0x04 => state.settings.initial_window_size = value,
            0x05 => {
                // Out-of-range values are ignored rather than rejected.
                if (16_384..=MAX_FRAME_PAYLOAD_LENGTH).contains(&value) {
                    state.settings.max_frame_size = value;
                }
            },
            0x06 => state.settings.max_header_list_size = value,
            // Unknown settings must be ignored (RFC 7540 §6.5.2).
            _ => {},
        }
    }
    Ok(())
}
/// Handles RST_STREAM (RFC 7540 §6.4). The 4-byte error code is logged, and the
/// stream is removed from `active_streams`.
///
/// A payload shorter than 4 bytes is logged and otherwise ignored; the stream is
/// kept in that case.
fn handle_rst_stream(state: &mut H2ConnectionState, header: &FrameHeader, payload: &[u8]) {
    let Some(code_bytes) = payload.first_chunk::<4>() else {
        crate::trace_warn!(
            "RST_STREAM on stream {} with short payload ({} bytes)",
            header.stream_id,
            payload.len()
        );
        return;
    };
    let _error_code = u32::from_be_bytes(*code_bytes);
    crate::trace_warn!(
        "RST_STREAM on stream {} error_code={_error_code}",
        header.stream_id
    );
    state.active_streams.remove(&header.stream_id);
}
/// Logs a GOAWAY frame (RFC 7540 §6.8): the last stream id (with its reserved
/// high bit masked off) and the error code. Any trailing debug data is ignored.
/// Connection state is not modified.
fn handle_goaway(payload: &[u8]) {
    let Some(fixed) = payload.first_chunk::<8>() else {
        crate::trace_warn!("GOAWAY with short payload ({} bytes)", payload.len());
        return;
    };
    let [s0, s1, s2, s3, e0, e1, e2, e3] = *fixed;
    let _last_stream_id = u32::from_be_bytes([s0 & 0x7F, s1, s2, s3]);
    let _error_code = u32::from_be_bytes([e0, e1, e2, e3]);
    crate::trace_warn!("GOAWAY: last_stream_id={_last_stream_id}, error_code={_error_code}");
}
/// HPACK-decodes `header_block` into `stream`.
///
/// Routing of decoded fields:
/// - `:method`, `:path`, `:authority` and `:scheme` fill the matching fields;
/// - `:status` is parsed into `stream.status` (`None` if it does not parse);
/// - every other field is appended to `stream.headers`.
///
/// The decoder is always run over the whole block, even after a problem is
/// detected. This keeps the decoder's dynamic table in step. Once a field is
/// rejected, it and all later fields are simply not recorded.
///
/// # Errors
///
/// - `Http2HpackError` if the decoder itself fails;
/// - `Http2InvalidHeaderEncoding` if a name or value is not UTF-8;
/// - `Http2HeaderListTooLarge` if `max_header_count`, `max_header_value_size` or
///   `max_header_list_size` is exceeded.
fn decode_headers_into_stream(
    decoder: &mut loona_hpack::Decoder<'static>,
    stream: &mut StreamState,
    header_block: &[u8],
    limits: &crate::state::H2Limits,
) -> Result<(), ParseError> {
    /// Reason the block was rejected. Recording stops at the first problem, so at
    /// most one reason is ever set.
    enum Rejection {
        Limit,
        Encoding,
    }

    let mut rejection: Option<Rejection> = None;
    let mut fields_seen: usize = 0;
    let mut list_size: usize = 0;

    decoder
        .decode_with_cb(header_block, |name, value| {
            if rejection.is_some() {
                return;
            }
            fields_seen += 1;
            let over_limit = fields_seen > limits.max_header_count
                || value.len() > limits.max_header_value_size
                || {
                    // Entry size as defined by RFC 7541 §4.1: name + value + 32.
                    list_size += name.len() + value.len() + 32;
                    list_size > limits.max_header_list_size
                };
            if over_limit {
                rejection = Some(Rejection::Limit);
                return;
            }
            let (Ok(name_str), Ok(value_str)) =
                (std::str::from_utf8(&name), std::str::from_utf8(&value))
            else {
                rejection = Some(Rejection::Encoding);
                return;
            };
            match name_str {
                ":method" => stream.method = Some(value_str.to_string()),
                ":path" => stream.path = Some(value_str.to_string()),
                ":authority" => stream.authority = Some(value_str.to_string()),
                ":scheme" => stream.scheme = Some(value_str.to_string()),
                ":status" => stream.status = value_str.parse().ok(),
                _ => stream
                    .headers
                    .push((name_str.to_string(), value_str.to_string())),
            }
        })
        .map_err(|e| ParseError::new(ParseErrorKind::Http2HpackError(format!("{e:?}"))))?;

    match rejection {
        None => Ok(()),
        Some(Rejection::Encoding) => {
            crate::trace_warn!(
                "HPACK decoded header with invalid UTF-8; dynamic table may contain tainted entry"
            );
            Err(ParseError::new(ParseErrorKind::Http2InvalidHeaderEncoding))
        },
        Some(Rejection::Limit) => {
            crate::trace_warn!("HPACK header list size/count limit exceeded");
            Err(ParseError::new(ParseErrorKind::Http2HeaderListTooLarge))
        },
    }
}
fn build_parsed_message_owned(stream_id: StreamId, stream: StreamState) -> ParsedH2Message {
ParsedH2Message {
method: stream.method,
path: stream.path,
authority: stream.authority,
scheme: stream.scheme,
status: stream.status,
headers: stream.headers,
stream_id,
header_size: stream.header_size,
body: stream.body,
first_frame_timestamp_ns: stream.first_frame_timestamp_ns,
end_stream_timestamp_ns: stream.end_stream_timestamp_ns,
}
}
/// If `stream_id` is tracked and its phase is [`StreamPhase::Complete`], removes
/// it from `active_streams` and returns it as a parsed message. Otherwise `state`
/// is left untouched and `None` is returned.
fn check_stream_completion(
    state: &mut H2ConnectionState,
    stream_id: StreamId,
) -> Option<(StreamId, ParsedH2Message)> {
    let is_complete = state
        .active_streams
        .get(&stream_id)
        .is_some_and(|s| s.phase == StreamPhase::Complete);
    if !is_complete {
        return None;
    }
    state
        .active_streams
        .remove(&stream_id)
        .map(|finished| (stream_id, build_parsed_message_owned(stream_id, finished)))
}