use std::collections::VecDeque;
use std::time::Instant;
use crate::bytes::{Bytes, BytesMut};
use crate::codec::{Decoder, Encoder};
use super::error::{ErrorCode, H2Error};
use super::frame::{
ContinuationFrame, DataFrame, FRAME_HEADER_SIZE, Frame, FrameHeader, FrameType, GoAwayFrame,
HeadersFrame, PingFrame, PushPromiseFrame, RstStreamFrame, Setting, SettingsFrame,
WindowUpdateFrame, parse_frame,
};
use super::hpack::{self, Header};
use super::settings::Settings;
use super::stream::{Stream, StreamState, StreamStore};
/// Connection preface that every HTTP/2 client sends first (RFC 9113 §3.4).
pub const CLIENT_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// Initial connection-level flow-control window in bytes (RFC 9113 §6.9.2).
pub const DEFAULT_CONNECTION_WINDOW_SIZE: i32 = 65_535;
/// Default number of RST_STREAM frames accepted per rate-limit window.
const DEFAULT_RST_STREAM_RATE_LIMIT: u32 = 100;
/// Default length of the RST_STREAM rate-limit window, in milliseconds (30 s).
const DEFAULT_RST_STREAM_RATE_WINDOW_MS: u128 = 30 * 1_000;
/// Limit on how many RST_STREAM frames the peer may send within one time
/// window before the connection is failed with ENHANCE_YOUR_CALM
/// (a defence against RST_STREAM floods).
#[derive(Debug, Clone, Copy)]
pub struct RstStreamRateLimit {
    /// Maximum number of RST_STREAM frames accepted within one window.
    pub max_rst_streams: u32,
    /// Window length in milliseconds; the counter restarts once it elapses.
    pub rst_window_ms: u128,
}

impl Default for RstStreamRateLimit {
    /// Uses `DEFAULT_RST_STREAM_RATE_LIMIT` frames per
    /// `DEFAULT_RST_STREAM_RATE_WINDOW_MS` milliseconds.
    fn default() -> Self {
        let max_rst_streams = DEFAULT_RST_STREAM_RATE_LIMIT;
        let rst_window_ms = DEFAULT_RST_STREAM_RATE_WINDOW_MS;
        Self {
            max_rst_streams,
            rst_window_ms,
        }
    }
}
/// Default time source for [`Connection`]: the monotonic system clock.
fn wall_clock_now() -> Instant {
    std::time::Instant::now()
}
/// Coarse lifecycle of a [`Connection`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Initial state; left once the peer's first SETTINGS frame is processed.
    Handshaking,
    /// The peer's SETTINGS have been received; normal operation.
    Open,
    /// A GOAWAY has been sent or received; no new streams may be opened.
    Closing,
    /// The connection is fully closed.
    Closed,
}
/// Incremental HTTP/2 frame codec operating on a [`BytesMut`] buffer.
#[derive(Debug)]
pub struct FrameCodec {
    /// Largest frame payload `decode` accepts; bigger frames are a FRAME_SIZE error.
    max_frame_size: u32,
    /// Header of a frame whose payload has not fully arrived yet. It has already
    /// been taken off the input buffer, so it is carried over to the next call.
    partial_header: Option<FrameHeader>,
}
impl FrameCodec {
    /// Creates a codec that accepts frames up to the protocol's default
    /// maximum frame size, with no partially-read frame.
    #[must_use]
    pub fn new() -> Self {
        let max_frame_size = super::frame::DEFAULT_MAX_FRAME_SIZE;
        Self {
            max_frame_size,
            partial_header: None,
        }
    }

    /// Changes the largest frame payload that `decode` will accept.
    pub fn set_max_frame_size(&mut self, size: u32) {
        self.max_frame_size = size;
    }
}
impl Default for FrameCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for FrameCodec {
    type Item = Frame;
    type Error = H2Error;

    /// Decodes the next complete, known-type frame from `src`.
    ///
    /// Returns `Ok(None)` when more bytes are needed. Frames whose type is not
    /// recognised are consumed and silently skipped. A header announcing a
    /// payload larger than `max_frame_size` is a FRAME_SIZE error.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        loop {
            // Resume a frame whose header was read on an earlier call, or read a new one.
            let header = match self.partial_header.take() {
                Some(pending) => pending,
                None if src.len() >= FRAME_HEADER_SIZE => FrameHeader::parse(src)?,
                None => return Ok(None),
            };
            let limit = self.max_frame_size;
            if header.length > limit {
                return Err(H2Error::frame_size(format!(
                    "frame too large: {} > {}",
                    header.length, limit
                )));
            }
            let needed = header.length as usize;
            if needed > src.len() {
                // Payload incomplete: remember the header and wait for more bytes.
                self.partial_header = Some(header);
                return Ok(None);
            }
            let body = src.split_to(needed).freeze();
            if FrameType::from_u8(header.frame_type).is_some() {
                let frame = parse_frame(&header, body)?;
                return Ok(Some(frame));
            }
            // Unknown frame type: its payload is discarded; keep scanning.
        }
    }
}
/// Encodes anything viewable as a [`Frame`], including `Frame` itself.
impl<T: AsRef<Frame>> Encoder<T> for FrameCodec {
    type Error = H2Error;

    /// Serialises the frame into `dst`; this implementation always succeeds.
    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let frame: &Frame = item.as_ref();
        frame.encode(dst);
        Ok(())
    }
}
impl AsRef<Self> for Frame {
fn as_ref(&self) -> &Self {
self
}
}
/// An outbound operation queued by [`Connection`] and turned into a wire
/// frame by `Connection::next_frame`.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum PendingOp {
    /// A SETTINGS frame to send as-is (e.g. the initial local settings).
    Settings(SettingsFrame),
    /// Acknowledgement of a SETTINGS frame received from the peer.
    SettingsAck,
    /// PING acknowledgement echoing the peer's opaque data.
    PingAck([u8; 8]),
    /// Flow-control credit; `stream_id == 0` addresses the connection window.
    WindowUpdate { stream_id: u32, increment: u32 },
    /// A header list, HPACK-encoded only when emitted; split into CONTINUATION
    /// frames if it exceeds the peer's maximum frame size.
    Headers {
        stream_id: u32,
        headers: Vec<Header>,
        end_stream: bool,
    },
    /// A header-block fragment produced by splitting an oversized HEADERS block.
    Continuation {
        stream_id: u32,
        header_block: Bytes,
        end_headers: bool,
    },
    /// Body bytes; sent subject to connection and stream flow control and
    /// split into several DATA frames when needed.
    Data {
        stream_id: u32,
        data: Bytes,
        end_stream: bool,
    },
    /// Abort a single stream with `error_code`.
    RstStream {
        stream_id: u32,
        error_code: ErrorCode,
    },
    /// Begin connection shutdown, advertising `last_stream_id`.
    GoAway {
        last_stream_id: u32,
        error_code: ErrorCode,
        debug_data: Bytes,
    },
}
/// Bookkeeping for a PUSH_PROMISE whose header block is still arriving in
/// CONTINUATION frames.
#[derive(Debug, Clone, Copy)]
struct PushPromiseAccumulator {
    /// Stream the PUSH_PROMISE (and therefore its CONTINUATIONs) arrived on.
    associated_stream_id: u32,
    /// Reserved stream on which the header fragments are buffered.
    promised_stream_id: u32,
}
/// Sans-I/O HTTP/2 connection state machine: consumes inbound frames via
/// `process_frame` and produces outbound frames via `next_frame`.
#[derive(Debug)]
#[allow(clippy::struct_excessive_bools)]
pub struct Connection {
    /// Coarse lifecycle state.
    state: ConnectionState,
    /// `true` for the client role, `false` for the server role.
    is_client: bool,
    /// Settings this endpoint advertises and enforces.
    local_settings: Settings,
    /// Settings announced by the peer (defaults until its SETTINGS arrive).
    remote_settings: Settings,
    /// Whether the peer's first non-ACK SETTINGS frame has been processed.
    received_settings: bool,
    /// All known streams.
    streams: StreamStore,
    /// HPACK encoder for outbound header blocks.
    hpack_encoder: hpack::Encoder,
    /// HPACK decoder for inbound header blocks.
    hpack_decoder: hpack::Decoder,
    /// Connection-level flow-control window for DATA we send.
    send_window: i32,
    /// Connection-level flow-control window for DATA we receive.
    recv_window: i32,
    /// Highest stream ID recorded as processed; advertised in our GOAWAY.
    last_stream_id: u32,
    /// Smallest last-stream-ID seen across all GOAWAY frames received.
    received_goaway_last_stream_id: Option<u32>,
    /// Last-stream-ID carried by the GOAWAY we sent, if any.
    sent_goaway_last_stream_id: Option<u32>,
    /// Whether a GOAWAY has been received.
    goaway_received: bool,
    /// Whether a GOAWAY has been queued for sending.
    goaway_sent: bool,
    /// Outbound operations drained by `next_frame`.
    pending_ops: VecDeque<PendingOp>,
    /// Clock source (injectable for deterministic tests).
    time_getter: fn() -> Instant,
    /// Stream whose header block is waiting for CONTINUATION frames.
    continuation_stream_id: Option<u32>,
    /// When the current CONTINUATION sequence started (for the timeout).
    continuation_started_at: Option<Instant>,
    /// PUSH_PROMISE whose header block is still being assembled.
    pending_push_promise: Option<PushPromiseAccumulator>,
    /// Configured RST_STREAM flood limit.
    rst_rate_limit: RstStreamRateLimit,
    /// RST_STREAM frames counted in the current rate-limit window.
    rst_stream_count: u32,
    /// Start of the current RST_STREAM rate-limit window.
    rst_stream_window_start: Instant,
}
impl Connection {
#[must_use]
pub fn client(settings: Settings) -> Self {
Self::client_with_time_getter(settings, wall_clock_now)
}
#[must_use]
pub fn client_with_time_getter(settings: Settings, time_getter: fn() -> Instant) -> Self {
let max_header_list_size = settings.max_header_list_size;
let initial_window = settings.initial_window_size;
let mut decoder = hpack::Decoder::new();
decoder.set_max_header_list_size(max_header_list_size as usize);
Self {
state: ConnectionState::Handshaking,
is_client: true,
local_settings: settings,
remote_settings: Settings::default(),
received_settings: false,
streams: StreamStore::new(true, initial_window, max_header_list_size),
hpack_encoder: hpack::Encoder::new(),
hpack_decoder: decoder,
send_window: DEFAULT_CONNECTION_WINDOW_SIZE,
recv_window: DEFAULT_CONNECTION_WINDOW_SIZE,
last_stream_id: 0,
received_goaway_last_stream_id: None,
sent_goaway_last_stream_id: None,
goaway_received: false,
goaway_sent: false,
pending_ops: VecDeque::new(),
time_getter,
continuation_stream_id: None,
continuation_started_at: None,
pending_push_promise: None,
rst_rate_limit: RstStreamRateLimit::default(),
rst_stream_count: 0,
rst_stream_window_start: time_getter(),
}
}
#[must_use]
pub fn server(settings: Settings) -> Self {
Self::server_with_time_getter(settings, wall_clock_now)
}
#[must_use]
pub fn server_with_time_getter(settings: Settings, time_getter: fn() -> Instant) -> Self {
let max_header_list_size = settings.max_header_list_size;
let initial_window = settings.initial_window_size;
let mut decoder = hpack::Decoder::new();
decoder.set_max_header_list_size(max_header_list_size as usize);
Self {
state: ConnectionState::Handshaking,
is_client: false,
local_settings: settings,
remote_settings: Settings::default(),
received_settings: false,
streams: StreamStore::new(false, initial_window, max_header_list_size),
hpack_encoder: hpack::Encoder::new(),
hpack_decoder: decoder,
send_window: DEFAULT_CONNECTION_WINDOW_SIZE,
recv_window: DEFAULT_CONNECTION_WINDOW_SIZE,
last_stream_id: 0,
received_goaway_last_stream_id: None,
sent_goaway_last_stream_id: None,
goaway_received: false,
goaway_sent: false,
pending_ops: VecDeque::new(),
time_getter,
continuation_stream_id: None,
continuation_started_at: None,
pending_push_promise: None,
rst_rate_limit: RstStreamRateLimit::default(),
rst_stream_count: 0,
rst_stream_window_start: time_getter(),
}
}
#[must_use]
pub fn rst_stream_rate_limit(mut self, limit: RstStreamRateLimit) -> Self {
self.rst_rate_limit = limit;
self
}
#[must_use]
pub fn state(&self) -> ConnectionState {
self.state
}
#[must_use]
pub fn is_client(&self) -> bool {
self.is_client
}
#[must_use]
pub fn local_settings(&self) -> &Settings {
&self.local_settings
}
#[must_use]
pub fn remote_settings(&self) -> &Settings {
&self.remote_settings
}
#[must_use]
pub fn send_window(&self) -> i32 {
self.send_window
}
#[must_use]
pub fn recv_window(&self) -> i32 {
self.recv_window
}
#[must_use]
pub fn stream(&self, id: u32) -> Option<&Stream> {
self.streams.get(id)
}
#[must_use]
pub fn stream_mut(&mut self, id: u32) -> Option<&mut Stream> {
self.streams.get_mut(id)
}
#[must_use]
pub fn goaway_received(&self) -> bool {
self.goaway_received
}
#[must_use]
pub fn is_awaiting_continuation(&self) -> bool {
self.continuation_stream_id.is_some()
}
#[must_use]
pub fn continuation_stream_id(&self) -> Option<u32> {
self.continuation_stream_id
}
pub fn check_continuation_timeout(&mut self) -> Result<(), H2Error> {
if let Some(started_at) = self.continuation_started_at {
let timeout_ms = self.local_settings.continuation_timeout_ms;
let elapsed = (self.time_getter)().saturating_duration_since(started_at);
if elapsed.as_millis() >= u128::from(timeout_ms) {
let stream_id = self.continuation_stream_id.take();
self.continuation_started_at = None;
self.pending_push_promise = None;
return Err(H2Error::protocol(format!(
"CONTINUATION timeout: no END_HEADERS within {timeout_ms}ms for stream {stream_id:?}",
)));
}
}
Ok(())
}
pub fn queue_initial_settings(&mut self) {
let settings = SettingsFrame::new(
self.local_settings
.to_settings_minimal_for_role(self.is_client),
);
self.pending_ops.push_back(PendingOp::Settings(settings));
}
pub fn open_stream(&mut self, headers: Vec<Header>, end_stream: bool) -> Result<u32, H2Error> {
if self.goaway_received || self.goaway_sent {
return Err(H2Error::protocol("cannot open new streams after GOAWAY"));
}
let stream_id = self.streams.allocate_stream_id()?;
let stream = self.streams.get_mut(stream_id).ok_or_else(|| {
H2Error::connection(
ErrorCode::InternalError,
"allocated stream missing from store",
)
})?;
stream.send_headers(end_stream)?;
self.pending_ops.push_back(PendingOp::Headers {
stream_id,
headers,
end_stream,
});
Ok(stream_id)
}
pub fn send_data(
&mut self,
stream_id: u32,
data: Bytes,
end_stream: bool,
) -> Result<(), H2Error> {
let stream = self.streams.get_mut(stream_id).ok_or_else(|| {
H2Error::stream(stream_id, ErrorCode::StreamClosed, "stream not found")
})?;
stream.send_data(end_stream)?;
self.pending_ops.push_back(PendingOp::Data {
stream_id,
data,
end_stream,
});
Ok(())
}
pub fn send_headers(
&mut self,
stream_id: u32,
headers: Vec<Header>,
end_stream: bool,
) -> Result<(), H2Error> {
let stream = self.streams.get_mut(stream_id).ok_or_else(|| {
H2Error::stream(stream_id, ErrorCode::StreamClosed, "stream not found")
})?;
stream.send_headers(end_stream)?;
self.pending_ops.push_back(PendingOp::Headers {
stream_id,
headers,
end_stream,
});
Ok(())
}
pub fn reset_stream(&mut self, stream_id: u32, error_code: ErrorCode) {
if let Some(stream) = self.streams.get_mut(stream_id) {
stream.reset(error_code);
}
self.pending_ops.push_back(PendingOp::RstStream {
stream_id,
error_code,
});
}
pub fn goaway(&mut self, error_code: ErrorCode, debug_data: Bytes) {
if !self.goaway_sent {
self.goaway_sent = true;
self.state = ConnectionState::Closing;
let last_stream_id = self.last_stream_id;
self.sent_goaway_last_stream_id = Some(last_stream_id);
self.pending_ops.push_back(PendingOp::GoAway {
last_stream_id,
error_code,
debug_data,
});
}
}
pub fn process_frame(&mut self, frame: Frame) -> Result<Option<ReceivedFrame>, H2Error> {
self.check_continuation_timeout()?;
if self.pending_ops.len() > 10_000 {
return Err(H2Error::connection(
ErrorCode::EnhanceYourCalm,
"too many pending operations, possible flood attack",
));
}
if let Some(expected_stream) = self.continuation_stream_id {
match &frame {
Frame::Continuation(cont) if cont.stream_id == expected_stream => {
}
_ => {
return Err(H2Error::protocol("expected CONTINUATION frame"));
}
}
}
let result = match frame {
Frame::Data(f) => self.process_data(f),
Frame::Headers(f) => self.process_headers(f),
Frame::Priority(f) => {
if let Some(stream) = self.streams.get_mut(f.stream_id) {
stream.set_priority(f.priority);
}
Ok(None)
}
Frame::RstStream(f) => self.process_rst_stream(f).map(Some),
Frame::Settings(f) => self.process_settings(&f),
Frame::PushPromise(f) => self.process_push_promise(&f),
Frame::Ping(f) => Ok(self.process_ping(f)),
Frame::GoAway(f) => Ok(Some(self.process_goaway(f))),
Frame::WindowUpdate(f) => self.process_window_update(f),
Frame::Continuation(f) => self.process_continuation(f),
Frame::Unknown { .. } => Ok(None), };
let max = self.local_settings.max_concurrent_streams as usize;
let threshold = std::cmp::min(max, 16_384).saturating_mul(2);
if self.streams.len() > threshold {
self.streams.prune_closed();
}
result
}
fn track_stream_id(&mut self, stream_id: u32) {
if stream_id > self.last_stream_id {
self.last_stream_id = stream_id;
}
}
fn stream_exceeds_sent_goaway(&self, stream_id: u32) -> bool {
self.sent_goaway_last_stream_id
.is_some_and(|last_stream_id| stream_id > last_stream_id)
}
fn stream_can_emit_queued_frames(&self, stream_id: u32) -> bool {
self.streams
.get(stream_id)
.is_some_and(|stream| stream.error_code().is_none())
}
fn process_data(&mut self, frame: DataFrame) -> Result<Option<ReceivedFrame>, H2Error> {
if self.streams.is_idle_stream_id(frame.stream_id) {
return Err(H2Error::protocol("DATA received on idle stream"));
}
let refused = self.stream_exceeds_sent_goaway(frame.stream_id);
if !refused {
self.track_stream_id(frame.stream_id);
}
let payload_len =
u32::try_from(frame.data.len()).map_err(|_| H2Error::frame_size("data too large"))?;
let window_delta = i32::try_from(payload_len)
.map_err(|_| H2Error::flow_control("data too large for window"))?;
if window_delta > self.recv_window {
return Err(H2Error::flow_control(
"data exceeds connection flow control window",
));
}
self.recv_window -= window_delta;
let low_watermark = DEFAULT_CONNECTION_WINDOW_SIZE / 2;
if self.recv_window < low_watermark {
let increment = i64::from(DEFAULT_CONNECTION_WINDOW_SIZE) - i64::from(self.recv_window);
let increment = u32::try_from(increment)
.map_err(|_| H2Error::flow_control("window increment too large"))?;
self.send_connection_window_update(increment)?;
}
let stream = self.streams.get_mut(frame.stream_id).ok_or_else(|| {
H2Error::stream(
frame.stream_id,
ErrorCode::StreamClosed,
"DATA received on closed stream",
)
})?;
stream.recv_data(payload_len, frame.end_stream)?;
if stream.state().can_recv() {
if let Some(increment) = stream.auto_window_update_increment() {
stream
.update_recv_window(i32::try_from(increment).map_err(|_| {
H2Error::flow_control("stream window increment too large")
})?)?;
self.pending_ops.push_back(PendingOp::WindowUpdate {
stream_id: frame.stream_id,
increment,
});
}
}
if refused {
Ok(None)
} else {
Ok(Some(ReceivedFrame::Data {
stream_id: frame.stream_id,
data: frame.data,
end_stream: frame.end_stream,
}))
}
}
fn process_headers(&mut self, frame: HeadersFrame) -> Result<Option<ReceivedFrame>, H2Error> {
let refused = self.stream_exceeds_sent_goaway(frame.stream_id);
{
let _ = self.streams.get_or_create(frame.stream_id)?;
}
if !refused {
self.track_stream_id(frame.stream_id);
}
let stream = self.streams.get_mut(frame.stream_id).ok_or_else(|| {
H2Error::connection(
ErrorCode::InternalError,
"stream disappeared after get_or_create",
)
})?;
stream.recv_headers(frame.end_stream, frame.end_headers)?;
if let Some(priority) = frame.priority {
stream.set_priority(priority);
}
stream.add_header_fragment(frame.header_block)?;
if frame.end_headers {
self.continuation_stream_id = None;
self.continuation_started_at = None;
let result = self.decode_headers(frame.stream_id, frame.end_stream);
if refused {
self.pending_ops.push_back(PendingOp::RstStream {
stream_id: frame.stream_id,
error_code: ErrorCode::RefusedStream,
});
result?; Ok(None)
} else {
result
}
} else {
self.continuation_stream_id = Some(frame.stream_id);
self.continuation_started_at = Some((self.time_getter)());
Ok(None)
}
}
fn process_continuation(
&mut self,
frame: ContinuationFrame,
) -> Result<Option<ReceivedFrame>, H2Error> {
if let Some(pending) = self.pending_push_promise {
if pending.associated_stream_id == frame.stream_id {
let promised_stream_id = pending.promised_stream_id;
let promised = self.streams.get_mut(promised_stream_id).ok_or_else(|| {
H2Error::stream(
promised_stream_id,
ErrorCode::StreamClosed,
"promised stream not found",
)
})?;
promised.add_header_fragment(frame.header_block)?;
if frame.end_headers {
self.pending_push_promise = None;
self.continuation_stream_id = None;
self.continuation_started_at = None;
return self.decode_push_promise(frame.stream_id, promised_stream_id);
}
return Ok(None);
}
}
let stream = self
.streams
.get_mut(frame.stream_id)
.ok_or_else(|| H2Error::protocol("CONTINUATION for unknown stream"))?;
stream.recv_continuation(frame.header_block, frame.end_headers)?;
if frame.end_headers {
self.continuation_stream_id = None;
self.continuation_started_at = None;
let end_stream = matches!(
stream.state(),
StreamState::HalfClosedRemote | StreamState::Closed
);
let refused = self.stream_exceeds_sent_goaway(frame.stream_id);
let result = self.decode_headers(frame.stream_id, end_stream);
if refused {
self.pending_ops.push_back(PendingOp::RstStream {
stream_id: frame.stream_id,
error_code: ErrorCode::RefusedStream,
});
result?; Ok(None)
} else {
result
}
} else {
Ok(None)
}
}
fn decode_headers(
&mut self,
stream_id: u32,
end_stream: bool,
) -> Result<Option<ReceivedFrame>, H2Error> {
let stream = self.streams.get_mut(stream_id).ok_or_else(|| {
H2Error::connection(ErrorCode::InternalError, "decode_headers missing stream")
})?;
let fragments = stream.take_header_fragments();
let total_len: usize = fragments.iter().map(Bytes::len).sum();
let max_fragment_size =
Stream::max_header_fragment_size_for(self.local_settings.max_header_list_size);
if total_len > max_fragment_size {
return Err(H2Error::stream(
stream_id,
ErrorCode::EnhanceYourCalm,
"accumulated header fragments too large",
));
}
let mut combined = BytesMut::with_capacity(total_len);
for fragment in fragments {
combined.extend_from_slice(&fragment);
}
let mut src = combined.freeze();
let headers = self.hpack_decoder.decode(&mut src)?;
Ok(Some(ReceivedFrame::Headers {
stream_id,
headers,
end_stream,
}))
}
fn decode_push_promise(
&mut self,
associated_stream_id: u32,
promised_stream_id: u32,
) -> Result<Option<ReceivedFrame>, H2Error> {
let promised = self.streams.get_mut(promised_stream_id).ok_or_else(|| {
H2Error::stream(
promised_stream_id,
ErrorCode::StreamClosed,
"promised stream not found",
)
})?;
let fragments = promised.take_header_fragments();
let total_len: usize = fragments.iter().map(Bytes::len).sum();
let max_fragment_size =
Stream::max_header_fragment_size_for(self.local_settings.max_header_list_size);
if total_len > max_fragment_size {
return Err(H2Error::stream(
promised_stream_id,
ErrorCode::EnhanceYourCalm,
"accumulated header fragments too large",
));
}
let mut combined = BytesMut::with_capacity(total_len);
for fragment in fragments {
combined.extend_from_slice(&fragment);
}
let mut src = combined.freeze();
let headers = self.hpack_decoder.decode(&mut src)?;
Ok(Some(ReceivedFrame::PushPromise {
stream_id: associated_stream_id,
promised_stream_id,
headers,
}))
}
fn process_rst_stream(&mut self, frame: RstStreamFrame) -> Result<ReceivedFrame, H2Error> {
if frame.stream_id == 0 {
return Err(H2Error::protocol("RST_STREAM with stream ID 0"));
}
if self.streams.is_idle_stream_id(frame.stream_id) {
return Err(H2Error::protocol("RST_STREAM received on idle stream"));
}
self.track_stream_id(frame.stream_id);
let elapsed = (self.time_getter)()
.saturating_duration_since(self.rst_stream_window_start)
.as_millis();
if elapsed >= self.rst_rate_limit.rst_window_ms {
self.rst_stream_count = 0;
self.rst_stream_window_start = (self.time_getter)();
}
if self.rst_stream_count >= self.rst_rate_limit.max_rst_streams {
return Err(H2Error::connection(
ErrorCode::EnhanceYourCalm,
"RST_STREAM flood detected",
));
}
self.rst_stream_count += 1;
if let Some(stream) = self.streams.get_mut(frame.stream_id) {
stream.reset(frame.error_code);
}
Ok(ReceivedFrame::Reset {
stream_id: frame.stream_id,
error_code: frame.error_code,
})
}
fn process_settings(
&mut self,
frame: &SettingsFrame,
) -> Result<Option<ReceivedFrame>, H2Error> {
if frame.ack {
return Ok(None);
}
for setting in &frame.settings {
if self.is_client && matches!(setting, Setting::EnablePush(_)) {
return Err(H2Error::protocol(
"server MUST NOT send SETTINGS_ENABLE_PUSH",
));
}
self.remote_settings.apply(*setting)?;
match setting {
Setting::InitialWindowSize(size) => {
self.streams.set_initial_window_size(*size)?;
}
Setting::HeaderTableSize(size) => {
let capped = (*size as usize).min(1024 * 1024);
self.hpack_encoder.set_max_table_size(capped);
}
Setting::MaxConcurrentStreams(max) => {
self.streams.set_max_concurrent_streams(*max);
}
Setting::MaxFrameSize(size) => {
let _ = size;
}
_ => {}
}
}
self.pending_ops.push_back(PendingOp::SettingsAck);
if !self.received_settings {
self.received_settings = true;
self.state = ConnectionState::Open;
}
Ok(None)
}
fn process_push_promise(
&mut self,
frame: &PushPromiseFrame,
) -> Result<Option<ReceivedFrame>, H2Error> {
if !self.is_client {
return Err(H2Error::protocol("server received PUSH_PROMISE"));
}
if !self.local_settings.enable_push {
return Err(H2Error::protocol("push not enabled"));
}
if frame.stream_id.is_multiple_of(2) {
return Err(H2Error::protocol("PUSH_PROMISE on server-initiated stream"));
}
if self.stream_exceeds_sent_goaway(frame.promised_stream_id) {
self.pending_ops.push_back(PendingOp::RstStream {
stream_id: frame.promised_stream_id,
error_code: ErrorCode::RefusedStream,
});
return Ok(None);
}
let assoc_state = match self.streams.get(frame.stream_id) {
Some(stream) => stream.state(),
None => {
return Err(H2Error::protocol("PUSH_PROMISE on unknown stream"));
}
};
if !matches!(
assoc_state,
StreamState::Open | StreamState::HalfClosedLocal
) {
let code = if assoc_state.is_closed() {
ErrorCode::StreamClosed
} else {
ErrorCode::ProtocolError
};
return Err(H2Error::stream(
frame.stream_id,
code,
"PUSH_PROMISE on stream not in open or half-closed (local) state",
));
}
let max_concurrent = self.local_settings.max_concurrent_streams;
if self.streams.active_count() as u32 >= max_concurrent {
return Err(H2Error::stream(
frame.promised_stream_id,
ErrorCode::RefusedStream,
"max concurrent streams exceeded",
));
}
let promised_stream_id = frame.promised_stream_id;
let promised_stream = self.streams.reserve_remote_stream(promised_stream_id)?;
promised_stream.add_header_fragment(frame.header_block.clone())?;
if frame.end_headers {
self.continuation_stream_id = None;
self.continuation_started_at = None;
self.decode_push_promise(frame.stream_id, promised_stream_id)
} else {
self.pending_push_promise = Some(PushPromiseAccumulator {
associated_stream_id: frame.stream_id,
promised_stream_id,
});
self.continuation_stream_id = Some(frame.stream_id);
self.continuation_started_at = Some((self.time_getter)());
Ok(None)
}
}
fn process_ping(&mut self, frame: PingFrame) -> Option<ReceivedFrame> {
if !frame.ack {
self.pending_ops
.push_back(PendingOp::PingAck(frame.opaque_data));
}
None
}
fn process_goaway(&mut self, frame: GoAwayFrame) -> ReceivedFrame {
self.goaway_received = true;
self.state = ConnectionState::Closing;
let effective_last_stream_id = self
.received_goaway_last_stream_id
.map_or(frame.last_stream_id, |previous| {
previous.min(frame.last_stream_id)
});
self.received_goaway_last_stream_id = Some(effective_last_stream_id);
for stream_id in self.streams.active_stream_ids() {
let is_local = (stream_id % 2 == 1) == self.is_client;
if is_local && stream_id > effective_last_stream_id {
if let Some(stream) = self.streams.get_mut(stream_id) {
stream.reset(ErrorCode::RefusedStream);
}
}
}
ReceivedFrame::GoAway {
last_stream_id: effective_last_stream_id,
error_code: frame.error_code,
debug_data: frame.debug_data,
}
}
fn process_window_update(
&mut self,
frame: WindowUpdateFrame,
) -> Result<Option<ReceivedFrame>, H2Error> {
let increment = i32::try_from(frame.increment)
.map_err(|_| H2Error::flow_control("window increment too large"))?;
if increment == 0 {
if frame.stream_id == 0 {
return Err(H2Error::protocol("WINDOW_UPDATE with zero increment"));
}
return Err(H2Error::stream(
frame.stream_id,
ErrorCode::ProtocolError,
"WINDOW_UPDATE with zero increment",
));
}
if frame.stream_id == 0 {
let new_window = i64::from(self.send_window) + i64::from(increment);
if new_window > i64::from(i32::MAX) {
return Err(H2Error::flow_control("connection window overflow"));
}
self.send_window = new_window as i32;
} else {
if self.streams.is_idle_stream_id(frame.stream_id) {
return Err(H2Error::protocol("WINDOW_UPDATE received on idle stream"));
}
if let Some(stream) = self.streams.get_mut(frame.stream_id) {
stream.update_send_window(increment)?;
}
}
Ok(None)
}
#[allow(clippy::too_many_lines)]
pub fn next_frame(&mut self) -> Option<Frame> {
let mut blocked_data = false;
let pending_len = self.pending_ops.len();
let mut skipped_ops = std::collections::VecDeque::new();
let mut newly_queued_ops = std::collections::VecDeque::new();
let mut returned_frame = None;
for _ in 0..pending_len {
let op = self.pending_ops.pop_front()?;
match op {
PendingOp::Settings(frame) => {
returned_frame = Some(Frame::Settings(frame));
break;
}
PendingOp::SettingsAck => {
returned_frame = Some(Frame::Settings(SettingsFrame::ack()));
break;
}
PendingOp::PingAck(data) => {
returned_frame = Some(Frame::Ping(PingFrame::ack(data)));
break;
}
PendingOp::WindowUpdate {
stream_id,
increment,
} => {
if stream_id != 0 && !self.stream_can_emit_queued_frames(stream_id) {
continue;
}
returned_frame = Some(Frame::WindowUpdate(WindowUpdateFrame::new(
stream_id, increment,
)));
break;
}
PendingOp::Headers {
stream_id,
headers,
end_stream,
} => {
if !self.stream_can_emit_queued_frames(stream_id) {
continue;
}
let mut encoded = BytesMut::new();
self.hpack_encoder.encode(&headers, &mut encoded);
let encoded = encoded.freeze();
let max_frame_size = self.remote_settings.max_frame_size as usize;
if encoded.len() <= max_frame_size {
returned_frame = Some(Frame::Headers(HeadersFrame::new(
stream_id, encoded, end_stream, true, )));
break;
}
let first_chunk = encoded.slice(..max_frame_size);
let remaining = encoded.slice(max_frame_size..);
let mut offset = 0;
while offset < remaining.len() {
let chunk_end = (offset + max_frame_size).min(remaining.len());
let chunk = remaining.slice(offset..chunk_end);
let is_last = chunk_end == remaining.len();
newly_queued_ops.push_back(PendingOp::Continuation {
stream_id,
header_block: chunk,
end_headers: is_last,
});
offset = chunk_end;
}
returned_frame = Some(Frame::Headers(HeadersFrame::new(
stream_id,
first_chunk,
end_stream,
false, )));
break;
}
PendingOp::Continuation {
stream_id,
header_block,
end_headers,
} => {
if !self.stream_can_emit_queued_frames(stream_id) {
continue;
}
returned_frame = Some(Frame::Continuation(ContinuationFrame {
stream_id,
header_block,
end_headers,
}));
break;
}
PendingOp::Data {
stream_id,
data,
end_stream,
} => {
let stream_avail = match self.streams.get(stream_id) {
Some(stream) if stream.error_code().is_none() => {
stream.send_window().max(0).cast_unsigned()
}
_ => continue,
};
let conn_avail = self.send_window.max(0).cast_unsigned();
let frame_size_limit = self.remote_settings.max_frame_size;
let max_send = conn_avail.min(stream_avail).min(frame_size_limit) as usize;
if max_send == 0 && !data.is_empty() {
skipped_ops.push_back(PendingOp::Data {
stream_id,
data,
end_stream,
});
blocked_data = true;
continue;
}
let send_len = data.len().min(max_send);
let (to_send, remainder) = if send_len < data.len() {
(data.slice(..send_len), Some(data.slice(send_len..)))
} else {
(data, None)
};
let actually_end = end_stream && remainder.is_none();
if let Some(rest) = remainder {
skipped_ops.push_back(PendingOp::Data {
stream_id,
data: rest,
end_stream,
});
}
let consumed = u32::try_from(to_send.len())
.expect("send_len already clamped to u32 range");
self.send_window -= consumed.cast_signed();
if let Some(stream) = self.streams.get_mut(stream_id) {
stream.consume_send_window(consumed);
}
returned_frame = Some(Frame::Data(DataFrame::new(
stream_id,
to_send,
actually_end,
)));
break;
}
PendingOp::RstStream {
stream_id,
error_code,
} => {
returned_frame =
Some(Frame::RstStream(RstStreamFrame::new(stream_id, error_code)));
break;
}
PendingOp::GoAway {
last_stream_id,
error_code,
debug_data,
} => {
let mut frame = GoAwayFrame::new(last_stream_id, error_code);
frame.debug_data = debug_data;
returned_frame = Some(Frame::GoAway(frame));
break;
}
}
}
for op in skipped_ops.into_iter().rev() {
self.pending_ops.push_front(op);
}
for op in newly_queued_ops.into_iter().rev() {
self.pending_ops.push_front(op);
}
if returned_frame.is_some() {
return returned_frame;
}
if blocked_data {
return None;
}
None
}
#[must_use]
pub fn has_pending_frames(&self) -> bool {
!self.pending_ops.is_empty()
}
pub fn send_connection_window_update(&mut self, increment: u32) -> Result<(), H2Error> {
if increment == 0 {
return Err(H2Error::flow_control(
"WINDOW_UPDATE increment must be non-zero (RFC 7540 ยง6.9)",
));
}
let delta = i32::try_from(increment)
.map_err(|_| H2Error::flow_control("window increment too large"))?;
let new_window = i64::from(self.recv_window) + i64::from(delta);
if new_window > i64::from(i32::MAX) {
return Err(H2Error::flow_control("connection window overflow"));
}
self.recv_window = new_window as i32;
self.pending_ops.push_back(PendingOp::WindowUpdate {
stream_id: 0,
increment,
});
Ok(())
}
pub fn send_stream_window_update(
&mut self,
stream_id: u32,
increment: u32,
) -> Result<(), H2Error> {
if increment == 0 {
return Err(H2Error::flow_control(
"WINDOW_UPDATE increment must be non-zero (RFC 7540 ยง6.9)",
));
}
let delta = i32::try_from(increment)
.map_err(|_| H2Error::flow_control("window increment too large"))?;
if let Some(stream) = self.streams.get_mut(stream_id) {
stream.update_recv_window(delta)?;
} else {
return Ok(());
}
self.pending_ops.push_back(PendingOp::WindowUpdate {
stream_id,
increment,
});
Ok(())
}
pub fn prune_closed_streams(&mut self) {
self.streams.prune_closed();
}
}
/// Application-visible event produced by `Connection::process_frame`.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum ReceivedFrame {
    /// A complete, HPACK-decoded header list (request, response or trailers).
    Headers {
        stream_id: u32,
        headers: Vec<Header>,
        end_stream: bool,
    },
    /// A server push: `stream_id` is the associated stream the promise arrived
    /// on, `promised_stream_id` the newly reserved stream.
    PushPromise {
        stream_id: u32,
        promised_stream_id: u32,
        headers: Vec<Header>,
    },
    /// Body bytes received on a stream.
    Data {
        stream_id: u32,
        data: Bytes,
        end_stream: bool,
    },
    /// The peer reset a stream.
    Reset {
        stream_id: u32,
        error_code: ErrorCode,
    },
    /// The peer is shutting down; `last_stream_id` is the smallest value seen
    /// across all GOAWAY frames received so far.
    GoAway {
        last_stream_id: u32,
        error_code: ErrorCode,
        debug_data: Bytes,
    },
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bytes::Bytes;
use crate::http::h2::settings;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};
use std::time::Duration;
static TEST_TIME_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
static TEST_NOW_BASE: OnceLock<Instant> = OnceLock::new();
static TEST_NOW_OFFSET_MS: AtomicU64 = AtomicU64::new(0);
fn lock_test_clock() -> std::sync::MutexGuard<'static, ()> {
TEST_TIME_LOCK
.get_or_init(|| Mutex::new(()))
.lock()
.expect("test time lock poisoned")
}
fn set_test_time_offset(duration: Duration) {
let millis = u64::try_from(duration.as_millis()).expect("duration fits u64 millis");
TEST_NOW_OFFSET_MS.store(millis, Ordering::SeqCst);
}
fn advance_test_time(duration: Duration) {
let millis = u64::try_from(duration.as_millis()).expect("duration fits u64 millis");
TEST_NOW_OFFSET_MS.fetch_add(millis, Ordering::SeqCst);
}
fn test_now() -> Instant {
TEST_NOW_BASE
.get_or_init(Instant::now)
.checked_add(Duration::from_millis(
TEST_NOW_OFFSET_MS.load(Ordering::SeqCst),
))
.expect("test instant overflow")
}
#[test]
fn data_frame_triggers_connection_window_update_on_low_watermark() {
let mut conn = Connection::server(Settings::default());
let payload_len = (DEFAULT_CONNECTION_WINDOW_SIZE / 2) + 2;
let payload_len_usize = usize::try_from(payload_len).expect("payload_len non-negative");
let payload_len_u32 = u32::try_from(payload_len).expect("payload_len fits u32");
let data = Bytes::from(vec![0_u8; payload_len_usize]);
let headers = Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true));
let frame = Frame::Data(DataFrame::new(1, data, false));
conn.process_frame(headers).expect("process headers frame");
conn.process_frame(frame).expect("process data frame");
assert!(conn.has_pending_frames(), "expected WINDOW_UPDATE(s)");
let mut found_connection_update = false;
while let Some(pending) = conn.next_frame() {
if let Frame::WindowUpdate(update) = pending {
if update.stream_id == 0 {
assert_eq!(update.increment, payload_len_u32);
found_connection_update = true;
}
}
}
assert!(
found_connection_update,
"expected connection-level WINDOW_UPDATE"
);
}
/// A DATA frame larger than the remaining connection receive window must be rejected
/// with FLOW_CONTROL_ERROR.
#[test]
fn data_frame_exceeding_connection_window_errors() {
    let mut conn = Connection::server(Settings::default());
    // Shrink the connection window so that a 2-byte payload overruns it.
    conn.recv_window = 1;
    let headers = Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true));
    conn.process_frame(headers).expect("process headers frame");
    let data = Bytes::from(vec![0_u8; 2]);
    let frame = Frame::Data(DataFrame::new(1, data, false));
    // `expect_err` already fails the test on `Ok`, so a separate `is_err` check is redundant.
    let err = conn
        .process_frame(frame)
        .expect_err("flow control error");
    assert_eq!(err.code, ErrorCode::FlowControlError);
}
/// DATA arriving on a stream that has already been reset must still be charged against
/// the connection receive window, even though the frame is rejected with STREAM_CLOSED.
#[test]
fn data_on_closed_stream_still_decrements_connection_window() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;

    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .unwrap();
    conn.process_frame(Frame::RstStream(RstStreamFrame::new(1, ErrorCode::Cancel)))
        .unwrap();
    assert_eq!(conn.stream(1).unwrap().state(), StreamState::Closed);

    let window_before = conn.recv_window();
    let late_data = Frame::Data(DataFrame::new(1, Bytes::from(vec![0_u8; 100]), false));
    let err = conn.process_frame(late_data).unwrap_err();
    assert_eq!(err.code, ErrorCode::StreamClosed);
    assert_eq!(
        conn.recv_window(),
        window_before - 100,
        "connection recv_window must be decremented even on stream-level errors"
    );
}
/// An encoded PING frame decodes back to a non-ACK PING with the same opaque payload.
#[test]
fn test_frame_codec_decode() {
    let mut buf = BytesMut::new();
    Frame::Ping(PingFrame::new([1, 2, 3, 4, 5, 6, 7, 8])).encode(&mut buf);

    let mut codec = FrameCodec::new();
    let Frame::Ping(ping) = codec.decode(&mut buf).unwrap().unwrap() else {
        panic!("expected PING frame");
    };
    assert_eq!(ping.opaque_data, [1, 2, 3, 4, 5, 6, 7, 8]);
    assert!(!ping.ack);
}
/// A frame of unknown type is skipped, and the decoder yields the next known frame.
#[test]
fn test_frame_codec_skips_unknown_frame_type() {
    let mut buf = BytesMut::new();
    let unknown = FrameHeader {
        length: 3,
        frame_type: 0xFF,
        flags: 0,
        stream_id: 0,
    };
    unknown.write(&mut buf);
    buf.extend_from_slice(&[1, 2, 3]);
    Frame::Ping(PingFrame::new([9, 8, 7, 6, 5, 4, 3, 2])).encode(&mut buf);

    let mut codec = FrameCodec::new();
    let Frame::Ping(p) = codec.decode(&mut buf).unwrap().unwrap() else {
        panic!("expected PING frame");
    };
    assert_eq!(p.opaque_data, [9, 8, 7, 6, 5, 4, 3, 2]);
}
/// A lone unknown frame is consumed from the buffer, and decoding reports no frame.
#[test]
fn test_frame_codec_unknown_frame_without_followup_returns_none() {
    let mut buf = BytesMut::new();
    let unknown = FrameHeader {
        length: 2,
        frame_type: 0xFE,
        flags: 0,
        stream_id: 0,
    };
    unknown.write(&mut buf);
    buf.extend_from_slice(&[0xAA, 0xBB]);

    let mut codec = FrameCodec::new();
    let outcome = codec.decode(&mut buf).unwrap();
    assert!(outcome.is_none(), "expected no decoded frame");
    assert!(buf.is_empty(), "unknown frame bytes should be consumed");
}
/// A client's initial SETTINGS frame is not an ACK and advertises ENABLE_PUSH = false.
#[test]
fn test_connection_client_settings() {
    let mut conn = Connection::client(Settings::client());
    conn.queue_initial_settings();
    assert!(conn.has_pending_frames());

    let Frame::Settings(initial) = conn.next_frame().unwrap() else {
        panic!("expected SETTINGS frame");
    };
    assert!(!initial.ack);
    let advertises_push_disabled = initial
        .settings
        .iter()
        .any(|setting| matches!(setting, Setting::EnablePush(false)));
    assert!(advertises_push_disabled);
}
/// Peer SETTINGS are ACKed and recorded as the connection's remote settings.
#[test]
fn test_connection_process_settings() {
    let mut conn = Connection::client(Settings::client());
    let peer_settings = SettingsFrame::new(vec![
        Setting::MaxConcurrentStreams(100),
        Setting::InitialWindowSize(32768),
    ]);
    conn.process_frame(Frame::Settings(peer_settings)).unwrap();
    assert!(conn.has_pending_frames());

    let Frame::Settings(reply) = conn.next_frame().unwrap() else {
        panic!("expected SETTINGS ACK");
    };
    assert!(reply.ack);

    let remote = conn.remote_settings();
    assert_eq!(remote.max_concurrent_streams, 100);
    assert_eq!(remote.initial_window_size, 32768);
}
/// Once the peer limits MAX_CONCURRENT_STREAMS to 2, opening a third stream fails.
#[test]
fn settings_max_concurrent_streams_constrains_open_stream() {
    let mut conn = Connection::client(Settings::client());
    conn.process_frame(Frame::Settings(SettingsFrame::new(vec![
        Setting::MaxConcurrentStreams(2),
    ])))
    .unwrap();
    // Drop the SETTINGS ACK queued in response.
    let _ = conn.next_frame();

    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    for _ in 0..2 {
        conn.open_stream(request.clone(), false).unwrap();
    }
    let third = conn.open_stream(request, false);
    assert!(
        third.is_err(),
        "third stream must be refused when peer MaxConcurrentStreams=2"
    );
}
/// A client treats ENABLE_PUSH in the server's SETTINGS as PROTOCOL_ERROR and does not ACK it.
#[test]
fn test_connection_client_rejects_server_enable_push_setting() {
    let mut conn = Connection::client(Settings::client());
    let bogus = Frame::Settings(SettingsFrame::new(vec![Setting::EnablePush(false)]));
    let err = conn.process_frame(bogus).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert!(!conn.has_pending_frames(), "invalid settings must not be ACKed");
}
/// A server's initial SETTINGS frame never carries an ENABLE_PUSH entry.
#[test]
fn test_connection_server_initial_settings_omit_enable_push() {
    let local = Settings {
        enable_push: false,
        ..Settings::server()
    };
    let mut conn = Connection::server(local);
    conn.queue_initial_settings();

    let Frame::Settings(initial) = conn.next_frame().expect("expected initial settings frame")
    else {
        panic!("expected SETTINGS frame");
    };
    let mentions_push = initial
        .settings
        .iter()
        .any(|setting| matches!(setting, Setting::EnablePush(_)));
    assert!(!mentions_push);
}
/// An inbound PING is answered with a PING ACK that echoes the opaque payload.
#[test]
fn test_connection_process_ping() {
    const OPAQUE: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
    let mut conn = Connection::client(Settings::client());
    conn.process_frame(Frame::Ping(PingFrame::new(OPAQUE))).unwrap();

    let Frame::Ping(reply) = conn.next_frame().unwrap() else {
        panic!("expected PING ACK");
    };
    assert!(reply.ack);
    assert_eq!(reply.opaque_data, OPAQUE);
}
/// The first client stream gets ID 1, and its request goes out as a single HEADERS frame.
#[test]
fn test_connection_open_stream() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let sid = conn.open_stream(request, false).unwrap();
    assert_eq!(sid, 1);

    let Frame::Headers(h) = conn.next_frame().unwrap() else {
        panic!("expected HEADERS frame");
    };
    assert_eq!(h.stream_id, 1);
    assert!(!h.end_stream);
    assert!(h.end_headers);
}
/// A DATA frame that drains a stream's receive window past its low-watermark must queue
/// a stream-level WINDOW_UPDATE whose increment equals the consumed payload length.
#[test]
fn data_frame_triggers_stream_window_update_on_low_watermark() {
    let mut conn = Connection::server(Settings::default());
    let headers = Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true));
    conn.process_frame(headers).expect("process headers");
    let initial_window = settings::DEFAULT_INITIAL_WINDOW_SIZE;
    // Just over half of the initial stream window, so the stream watermark is crossed.
    let payload_len = initial_window / 2 + 2;
    // Use a checked conversion instead of `as`, matching the connection-window test.
    let payload_len_usize = usize::try_from(payload_len).expect("payload_len fits usize");
    let data = Bytes::from(vec![0_u8; payload_len_usize]);
    let frame = Frame::Data(DataFrame::new(1, data, false));
    conn.process_frame(frame).expect("process data");
    let mut found_stream_update = false;
    while let Some(f) = conn.next_frame() {
        if let Frame::WindowUpdate(wu) = f {
            if wu.stream_id == 1 {
                found_stream_update = true;
                assert_eq!(wu.increment, payload_len);
            }
        }
    }
    assert!(
        found_stream_update,
        "expected stream-level WINDOW_UPDATE for stream 1"
    );
}
/// A small DATA frame that leaves the stream window above its watermark must not
/// trigger a WINDOW_UPDATE for that stream.
#[test]
fn data_frame_no_stream_window_update_when_above_watermark() {
    let mut conn = Connection::server(Settings::default());
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .expect("process headers");
    conn.process_frame(Frame::Data(DataFrame::new(1, Bytes::from(vec![0_u8; 100]), false)))
        .expect("process data");

    for queued in std::iter::from_fn(|| conn.next_frame()) {
        if let Frame::WindowUpdate(wu) = queued {
            assert_ne!(wu.stream_id, 1, "unexpected stream-level WINDOW_UPDATE");
        }
    }
}
/// Outbound DATA is clamped to the connection send window. The remainder is held back
/// until more connection-level credit is available.
#[test]
fn send_data_respects_send_window() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "POST"),
        Header::new(":path", "/upload"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let sid = conn.open_stream(request, false).unwrap();
    // Discard the request HEADERS frame.
    let _ = conn.next_frame().unwrap();

    // Only 100 bytes of connection credit for a 300-byte body.
    conn.send_window = 100;
    conn.send_data(sid, Bytes::from(vec![0xAB_u8; 300]), true)
        .unwrap();

    match conn.next_frame().expect("expected first DATA frame") {
        Frame::Data(chunk) => {
            assert_eq!(chunk.data.len(), 100, "should be clamped to send window");
            assert!(!chunk.end_stream, "not the final chunk");
        }
        other => panic!("expected DATA frame, got {other:?}"),
    }

    // Granting more credit releases the rest of the body.
    conn.send_window = 300;
    match conn.next_frame().expect("expected second DATA frame") {
        Frame::Data(chunk) => {
            assert_eq!(chunk.data.len(), 200, "remaining 200 bytes");
            assert!(chunk.end_stream, "final chunk should carry end_stream");
        }
        other => panic!("expected DATA frame, got {other:?}"),
    }
    assert!(!conn.has_pending_frames(), "all data should be sent");
}
/// Outbound DATA is clamped to the per-stream send window. The rest flows once the
/// stream window is replenished.
#[test]
fn send_data_respects_stream_send_window() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "POST"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let sid = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame().unwrap();

    // Leave only 50 bytes of stream-level credit.
    conn.stream_mut(sid)
        .unwrap()
        .consume_send_window(65535 - 50);
    conn.send_data(sid, Bytes::from(vec![0xCD_u8; 200]), true)
        .unwrap();

    match conn.next_frame().expect("expected first DATA frame") {
        Frame::Data(chunk) => {
            assert_eq!(chunk.data.len(), 50, "clamped to stream send window");
            assert!(!chunk.end_stream);
        }
        other => panic!("expected DATA frame, got {other:?}"),
    }

    // Top the stream window back up so the remainder can be sent.
    conn.stream_mut(sid)
        .unwrap()
        .update_send_window(200)
        .unwrap();
    match conn.next_frame().expect("expected second DATA frame") {
        Frame::Data(chunk) => {
            assert_eq!(chunk.data.len(), 150);
            assert!(chunk.end_stream);
        }
        other => panic!("expected DATA frame, got {other:?}"),
    }
}
/// Outbound DATA is split into chunks no larger than the peer's max frame size, and
/// END_STREAM is set only on the last chunk.
#[test]
fn send_data_respects_max_frame_size() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    conn.remote_settings.max_frame_size = 100;
    let request = vec![
        Header::new(":method", "POST"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let sid = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame().unwrap();
    conn.send_data(sid, Bytes::from(vec![0xEE_u8; 300]), true)
        .unwrap();

    // Each entry: (message if the frame is missing, whether it must carry END_STREAM).
    let expected = [
        ("expected first DATA frame", false),
        ("expected second DATA frame", false),
        ("expected third DATA frame", true),
    ];
    for (index, &(missing, is_last)) in expected.iter().enumerate() {
        match conn.next_frame().expect(missing) {
            Frame::Data(d) => {
                if index == 0 {
                    assert_eq!(d.data.len(), 100, "clamped to max_frame_size");
                } else {
                    assert_eq!(d.data.len(), 100);
                }
                assert_eq!(d.end_stream, is_last);
            }
            other => panic!("expected DATA frame, got {other:?}"),
        }
    }
    assert!(!conn.has_pending_frames());
}
/// The final DATA frame is still emitted even though sending it moved the stream to Closed.
#[test]
fn final_data_flushes_after_stream_enters_closed_state() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "POST"),
        Header::new(":path", "/upload"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let sid = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame().unwrap();

    // The peer's response ends its side, leaving the stream half-closed (remote).
    conn.process_frame(Frame::Headers(HeadersFrame::new(sid, Bytes::new(), true, true)))
        .unwrap();
    assert_eq!(
        conn.stream(sid).unwrap().state(),
        StreamState::HalfClosedRemote
    );

    // Sending our END_STREAM closes the stream before the frame is pulled from the queue.
    conn.send_data(sid, Bytes::from_static(b"payload"), true)
        .unwrap();
    assert_eq!(conn.stream(sid).unwrap().state(), StreamState::Closed);

    match conn
        .next_frame()
        .expect("final DATA must still be emitted after local close")
    {
        Frame::Data(data) => {
            assert_eq!(data.stream_id, sid);
            assert_eq!(data.data, Bytes::from_static(b"payload"));
            assert!(data.end_stream);
        }
        other => panic!("expected DATA frame, got {other:?}"),
    }
}
/// A header block larger than the peer's max frame size is split into a HEADERS frame
/// followed by CONTINUATION frames. The last CONTINUATION carries END_HEADERS.
#[test]
fn large_headers_use_continuation_frames() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    // A tiny peer frame size forces the encoded header block to be fragmented.
    conn.remote_settings.max_frame_size = 50;
    let mut request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/some/very/long/path/that/exceeds/frame/size"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    request.extend(
        (0..10).map(|i| Header::new(format!("x-custom-header-{i}"), format!("value-{i}"))),
    );
    let sid = conn.open_stream(request, true).unwrap();

    match conn.next_frame().expect("expected HEADERS frame") {
        Frame::Headers(h) => {
            assert_eq!(h.stream_id, sid);
            assert!(h.end_stream);
            assert!(!h.end_headers, "should have CONTINUATION following");
            assert_eq!(h.header_block.len(), 50);
        }
        other => panic!("expected HEADERS frame, got {other:?}"),
    }

    let mut continuation_count = 0;
    let mut saw_end_headers = false;
    while !saw_end_headers {
        let Some(frame) = conn.next_frame() else {
            break;
        };
        match frame {
            Frame::Continuation(c) => {
                assert_eq!(c.stream_id, sid);
                continuation_count += 1;
                saw_end_headers = c.end_headers;
            }
            other => panic!("expected CONTINUATION frame, got {other:?}"),
        }
    }
    assert!(
        continuation_count >= 1,
        "should have at least one CONTINUATION"
    );
    assert!(saw_end_headers, "last frame should have end_headers=true");
}
/// With push disabled locally, an inbound PUSH_PROMISE is a PROTOCOL_ERROR.
#[test]
fn push_promise_rejected_when_disabled() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();

    let promise = PushPromiseFrame {
        stream_id: parent,
        promised_stream_id: 2,
        header_block: Bytes::new(),
        end_headers: true,
    };
    let err = conn.process_frame(Frame::PushPromise(promise)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
}
/// An accepted PUSH_PROMISE is surfaced to the caller, and the promised stream is
/// created in the reserved (remote) state.
#[test]
fn push_promise_creates_reserved_stream() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();

    let promise = Frame::PushPromise(PushPromiseFrame {
        stream_id: parent,
        promised_stream_id: 2,
        header_block: Bytes::new(),
        end_headers: true,
    });
    match conn.process_frame(promise).unwrap().unwrap() {
        ReceivedFrame::PushPromise {
            promised_stream_id, ..
        } => assert_eq!(promised_stream_id, 2),
        other => panic!("expected PushPromise frame, got {other:?}"),
    }

    let reserved = conn.stream(2).expect("promised stream exists");
    assert_eq!(reserved.state(), StreamState::ReservedRemote);
}
/// A PUSH_PROMISE whose header block is split across a CONTINUATION frame is surfaced
/// only when END_HEADERS arrives, and it carries the fully decoded header list.
#[test]
fn push_promise_continuation_accumulates() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();

    let mut promised = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/pushed"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let mut block = BytesMut::new();
    conn.hpack_encoder.encode(&promised, &mut block);
    // The block needs at least two bytes so both fragments are non-empty.
    if block.len() < 2 {
        promised.push(Header::new("x-extra", "1"));
        block.clear();
        conn.hpack_encoder.encode(&promised, &mut block);
    }
    assert!(block.len() >= 2);
    let block = block.freeze();
    let midpoint = block.len() / 2;
    let (head, tail) = (block.slice(..midpoint), block.slice(midpoint..));

    let opening = Frame::PushPromise(PushPromiseFrame {
        stream_id: parent,
        promised_stream_id: 2,
        header_block: head,
        end_headers: false,
    });
    // Nothing is surfaced until the header block is complete.
    assert!(conn.process_frame(opening).unwrap().is_none());

    let closing = Frame::Continuation(ContinuationFrame {
        stream_id: parent,
        header_block: tail,
        end_headers: true,
    });
    match conn.process_frame(closing).unwrap().unwrap() {
        ReceivedFrame::PushPromise {
            promised_stream_id,
            headers: decoded,
            ..
        } => {
            assert_eq!(promised_stream_id, 2);
            assert_eq!(decoded, promised);
        }
        other => panic!("expected PushPromise frame, got {other:?}"),
    }
}
/// A server connection rejects any inbound PUSH_PROMISE with PROTOCOL_ERROR.
#[test]
fn push_promise_rejected_on_server_connection() {
    let mut conn = Connection::server(Settings::server());
    conn.state = ConnectionState::Open;
    let promise = PushPromiseFrame {
        stream_id: 1,
        promised_stream_id: 2,
        header_block: Bytes::new(),
        end_headers: true,
    };
    let err = conn.process_frame(Frame::PushPromise(promise)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
}
/// A PUSH_PROMISE that promises an odd (client-side) stream ID is a PROTOCOL_ERROR.
#[test]
fn push_promise_rejected_for_invalid_promised_id() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();

    let promise = PushPromiseFrame {
        stream_id: parent,
        promised_stream_id: 3,
        header_block: Bytes::new(),
        end_headers: true,
    };
    let err = conn.process_frame(Frame::PushPromise(promise)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
}
/// A PUSH_PROMISE on a stream the client never opened fails with a connection-level
/// PROTOCOL_ERROR: the error carries no stream ID.
#[test]
fn push_promise_rejected_for_unknown_associated_stream() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;

    let promise = PushPromiseFrame {
        stream_id: 1,
        promised_stream_id: 2,
        header_block: Bytes::new(),
        end_headers: true,
    };
    let err = conn.process_frame(Frame::PushPromise(promise)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert_eq!(err.stream_id, None);
}
/// With no header block in progress, the CONTINUATION timeout check is a no-op.
#[test]
fn continuation_timeout_not_triggered_when_no_continuation() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let check = conn.check_continuation_timeout();
    assert!(check.is_ok());
    assert!(!conn.is_awaiting_continuation());
}
/// While still inside the CONTINUATION deadline, the timeout check passes and the
/// connection keeps waiting for the rest of the header block.
#[test]
fn continuation_timeout_not_triggered_when_within_limit() {
    let _clock = lock_test_clock();
    set_test_time_offset(Duration::ZERO);
    let settings = Settings {
        continuation_timeout_ms: 5000,
        ..Default::default()
    };
    let mut conn = Connection::server_with_time_getter(settings, test_now);
    conn.state = ConnectionState::Open;

    // HEADERS without END_HEADERS leaves the connection waiting for CONTINUATION.
    let outcome =
        conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, false)));
    assert!(outcome.is_ok());
    assert!(conn.is_awaiting_continuation());

    advance_test_time(Duration::from_millis(10));
    assert!(conn.check_continuation_timeout().is_ok());
    assert!(conn.is_awaiting_continuation());
}
/// Completing a header block with END_HEADERS clears the pending-CONTINUATION state
/// and its start timestamp.
#[test]
fn continuation_clears_timeout_on_completion() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, false)))
        .unwrap();
    assert!(conn.is_awaiting_continuation());
    assert!(conn.continuation_started_at.is_some());

    let finish = ContinuationFrame {
        stream_id: 1,
        header_block: Bytes::new(),
        end_headers: true,
    };
    conn.process_frame(Frame::Continuation(finish)).unwrap();
    assert!(!conn.is_awaiting_continuation());
    assert!(conn.continuation_started_at.is_none());
}
/// Past the CONTINUATION deadline, the timeout check fails with PROTOCOL_ERROR and
/// resets the pending-CONTINUATION state.
#[test]
fn continuation_timeout_triggers_after_expiry() {
    let _clock = lock_test_clock();
    set_test_time_offset(Duration::ZERO);
    let mut conn = Connection::server_with_time_getter(
        Settings {
            continuation_timeout_ms: 50,
            ..Default::default()
        },
        test_now,
    );
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, false)))
        .unwrap();
    assert!(conn.is_awaiting_continuation());

    // Step the test clock past the 50 ms deadline.
    advance_test_time(Duration::from_millis(60));
    let err = conn.check_continuation_timeout().unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert!(err.message.contains("CONTINUATION timeout"));
    assert!(!conn.is_awaiting_continuation());
    assert!(conn.continuation_started_at.is_none());
}
/// A CONTINUATION that arrives after the deadline is rejected with the timeout error.
#[test]
fn continuation_timeout_on_next_frame() {
    let _clock = lock_test_clock();
    set_test_time_offset(Duration::ZERO);
    let mut conn = Connection::server_with_time_getter(
        Settings {
            continuation_timeout_ms: 50,
            ..Default::default()
        },
        test_now,
    );
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, false)))
        .unwrap();
    advance_test_time(Duration::from_millis(60));

    let late = ContinuationFrame {
        stream_id: 1,
        header_block: Bytes::new(),
        end_headers: true,
    };
    let err = conn.process_frame(Frame::Continuation(late)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert!(err.message.contains("CONTINUATION timeout"));
}
/// An unfinished PUSH_PROMISE header block is also subject to the CONTINUATION timeout.
#[test]
fn push_promise_continuation_timeout() {
    let _clock = lock_test_clock();
    set_test_time_offset(Duration::ZERO);
    let settings = Settings {
        enable_push: true,
        continuation_timeout_ms: 50,
        ..Settings::client()
    };
    let mut conn = Connection::client_with_time_getter(settings, test_now);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame();

    let partial = PushPromiseFrame {
        stream_id: parent,
        promised_stream_id: 2,
        header_block: Bytes::new(),
        end_headers: false,
    };
    conn.process_frame(Frame::PushPromise(partial)).unwrap();
    assert!(conn.is_awaiting_continuation());

    advance_test_time(Duration::from_millis(60));
    let err = conn.check_continuation_timeout().unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert!(err.message.contains("CONTINUATION timeout"));
}
/// A PUSH_PROMISE on a fully closed associated stream is rejected with STREAM_CLOSED.
#[test]
fn push_promise_rejected_on_closed_stream() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    // Our request ends its side of the stream...
    let parent = conn.open_stream(request, true).unwrap();
    let _ = conn.next_frame();
    // ...and the peer's response ends the other side, closing the stream.
    conn.process_frame(Frame::Headers(HeadersFrame::new(parent, Bytes::new(), true, true)))
        .unwrap();
    assert_eq!(conn.stream(parent).unwrap().state(), StreamState::Closed);

    let promise = PushPromiseFrame {
        stream_id: parent,
        promised_stream_id: 2,
        header_block: Bytes::new(),
        end_headers: true,
    };
    let err = conn.process_frame(Frame::PushPromise(promise)).unwrap_err();
    assert_eq!(err.code, ErrorCode::StreamClosed);
}
/// A PUSH_PROMISE on a stream that is half-closed (remote) is a PROTOCOL_ERROR.
#[test]
fn push_promise_rejected_on_half_closed_remote_stream() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame();
    conn.process_frame(Frame::Headers(HeadersFrame::new(parent, Bytes::new(), true, true)))
        .unwrap();
    assert_eq!(
        conn.stream(parent).unwrap().state(),
        StreamState::HalfClosedRemote
    );

    let promise = PushPromiseFrame {
        stream_id: parent,
        promised_stream_id: 2,
        header_block: Bytes::new(),
        end_headers: true,
    };
    let err = conn.process_frame(Frame::PushPromise(promise)).unwrap_err();
    assert_eq!(
        err.code,
        ErrorCode::ProtocolError,
        "PUSH_PROMISE on half-closed (remote) stream must be PROTOCOL_ERROR"
    );
}
/// Promised streams count against the local MAX_CONCURRENT_STREAMS limit. Once it is
/// reached, further pushes are refused with REFUSED_STREAM.
#[test]
fn push_promise_enforces_max_concurrent_streams() {
    let settings = Settings {
        enable_push: true,
        max_concurrent_streams: 3,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame();

    // The request stream plus two promised streams fill the limit of 3.
    for promised in [2, 4] {
        let push = Frame::PushPromise(PushPromiseFrame {
            stream_id: parent,
            promised_stream_id: promised,
            header_block: Bytes::new(),
            end_headers: true,
        });
        assert!(conn.process_frame(push).is_ok());
    }

    let overflow = Frame::PushPromise(PushPromiseFrame {
        stream_id: parent,
        promised_stream_id: 6,
        header_block: Bytes::new(),
        end_headers: true,
    });
    let err = conn.process_frame(overflow).unwrap_err();
    assert_eq!(err.code, ErrorCode::RefusedStream);
}
/// Promising the same stream ID twice is a PROTOCOL_ERROR.
#[test]
fn push_promise_rejected_for_duplicate_stream_id() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame();

    let promise = |promised| {
        Frame::PushPromise(PushPromiseFrame {
            stream_id: parent,
            promised_stream_id: promised,
            header_block: Bytes::new(),
            end_headers: true,
        })
    };
    assert!(conn.process_frame(promise(2)).is_ok());
    let err = conn.process_frame(promise(2)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
}
/// Promised stream IDs must increase. A promise below an earlier one is a PROTOCOL_ERROR.
#[test]
fn push_promise_monotonic_stream_id() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame();

    let promise = |promised| {
        Frame::PushPromise(PushPromiseFrame {
            stream_id: parent,
            promised_stream_id: promised,
            header_block: Bytes::new(),
            end_headers: true,
        })
    };
    assert!(conn.process_frame(promise(4)).is_ok());
    let err = conn.process_frame(promise(2)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
}
/// A flood of PUSH_PROMISE frames is bounded by MAX_CONCURRENT_STREAMS. Excess promises
/// are refused with REFUSED_STREAM rather than accepted or treated as fatal.
#[test]
fn push_promise_attack_flood_bounded() {
    let settings = Settings {
        enable_push: true,
        max_concurrent_streams: 10,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let parent = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame();

    let (mut accepted, mut rejected) = (0, 0);
    // Even promised IDs 2, 4, ..., 200.
    for promised_id in (1..=100).map(|n| n * 2) {
        let push = Frame::PushPromise(PushPromiseFrame {
            stream_id: parent,
            promised_stream_id: promised_id,
            header_block: Bytes::new(),
            end_headers: true,
        });
        match conn.process_frame(push) {
            Ok(_) => accepted += 1,
            Err(e) if e.code == ErrorCode::RefusedStream => rejected += 1,
            Err(e) => panic!("unexpected error: {e:?}"),
        }
    }
    assert_eq!(
        accepted, 9,
        "should accept max_concurrent_streams - 1 pushes"
    );
    assert_eq!(rejected, 91, "should reject the rest");
}
/// A PUSH_PROMISE associated with an even (server-initiated) stream ID is a PROTOCOL_ERROR.
#[test]
fn push_promise_on_server_initiated_stream_rejected() {
    let settings = Settings {
        enable_push: true,
        ..Settings::client()
    };
    let mut conn = Connection::client(settings);
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let _ = conn.open_stream(request, false).unwrap();
    let _ = conn.next_frame();

    let promise = PushPromiseFrame {
        stream_id: 2,
        promised_stream_id: 4,
        header_block: Bytes::new(),
        end_headers: true,
    };
    let err = conn.process_frame(Frame::PushPromise(promise)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
}
/// A SETTINGS ACK from the peer is accepted and produces nothing for the caller.
#[test]
fn test_settings_ack_is_no_op() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let ack_frame = Frame::Settings(SettingsFrame::ack());
    // `expect` checks success and yields the payload in one step, replacing the
    // redundant `is_ok()` + `unwrap()` pair.
    let received = conn
        .process_frame(ack_frame)
        .expect("SETTINGS ACK must be accepted");
    assert!(received.is_none());
}
/// Remote settings start at their defaults and are overwritten by values from the
/// peer's SETTINGS frame.
#[test]
fn test_settings_updates_remote_settings() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;

    let defaults = conn.remote_settings();
    assert_eq!(defaults.max_concurrent_streams, 256);
    assert_eq!(
        defaults.initial_window_size,
        settings::DEFAULT_INITIAL_WINDOW_SIZE
    );

    let update = SettingsFrame::new(vec![
        Setting::MaxConcurrentStreams(50),
        Setting::InitialWindowSize(32768),
        Setting::MaxFrameSize(32768),
    ]);
    conn.process_frame(Frame::Settings(update)).unwrap();

    let remote = conn.remote_settings();
    assert_eq!(remote.max_concurrent_streams, 50);
    assert_eq!(remote.initial_window_size, 32768);
    assert_eq!(remote.max_frame_size, 32768);
}
/// SETTINGS_INITIAL_WINDOW_SIZE above 2^31-1 is rejected with FLOW_CONTROL_ERROR.
#[test]
fn test_settings_invalid_initial_window_size() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let oversized = SettingsFrame::new(vec![Setting::InitialWindowSize(0x8000_0000)]);
    let err = conn.process_frame(Frame::Settings(oversized)).unwrap_err();
    assert_eq!(err.code, ErrorCode::FlowControlError);
}
/// SETTINGS_MAX_FRAME_SIZE below the protocol minimum is rejected with PROTOCOL_ERROR.
#[test]
fn test_settings_invalid_max_frame_size() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let undersized = SettingsFrame::new(vec![Setting::MaxFrameSize(100)]);
    let err = conn.process_frame(Frame::Settings(undersized)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
}
/// Receiving the peer's first SETTINGS frame moves the connection from Handshaking to Open.
#[test]
fn test_settings_transitions_to_open() {
    let mut conn = Connection::server(Settings::default());
    assert_eq!(conn.state, ConnectionState::Handshaking);
    conn.process_frame(Frame::Settings(SettingsFrame::new(vec![])))
        .unwrap();
    assert_eq!(conn.state, ConnectionState::Open);
}
/// After a GOAWAY is received, the connection is Closing and refuses to open new streams.
#[test]
fn test_goaway_rejects_new_streams() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    conn.open_stream(request.clone(), false).unwrap();

    conn.process_frame(Frame::GoAway(GoAwayFrame::new(1, ErrorCode::NoError)))
        .unwrap();
    assert!(conn.goaway_received());
    assert_eq!(conn.state, ConnectionState::Closing);

    let refused = conn.open_stream(request, false).unwrap_err();
    assert_eq!(refused.code, ErrorCode::ProtocolError);
}
/// After we send a GOAWAY, the connection is Closing and refuses to open new streams.
#[test]
fn test_goaway_sent_rejects_new_streams() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];

    conn.goaway(ErrorCode::NoError, Bytes::new());
    assert!(conn.goaway_sent);
    assert_eq!(conn.state, ConnectionState::Closing);

    let refused = conn.open_stream(request, false).unwrap_err();
    assert_eq!(refused.code, ErrorCode::ProtocolError);
}
/// A received GOAWAY closes every local stream above its last_stream_id and leaves
/// lower streams untouched.
#[test]
fn test_goaway_resets_streams_above_last_id() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];

    // Open three streams, discarding each request HEADERS frame.
    let opened: Vec<_> = (0..3)
        .map(|_| {
            let id = conn.open_stream(request.clone(), false).unwrap();
            let _ = conn.next_frame();
            id
        })
        .collect();
    assert_eq!(opened, [1, 3, 5]);

    let goaway = Frame::GoAway(GoAwayFrame::new(1, ErrorCode::NoError));
    match conn.process_frame(goaway).unwrap().unwrap() {
        ReceivedFrame::GoAway {
            last_stream_id,
            error_code,
            ..
        } => {
            assert_eq!(last_stream_id, 1);
            assert_eq!(error_code, ErrorCode::NoError);
        }
        _ => panic!("expected GoAway"),
    }

    assert!(!conn.stream(1).unwrap().state().is_closed());
    for id in [3, 5] {
        assert_eq!(conn.stream(id).unwrap().state(), StreamState::Closed);
    }
}
/// Successive received GOAWAY frames may only narrow the peer's boundary:
/// - A first GOAWAY(5) closes stream 7.
/// - A later GOAWAY(7) is still reported with boundary 5, but carries the new
///   error code.
/// - A GOAWAY(1) narrows further and closes streams 3 and 5.
#[test]
fn test_goaway_received_last_stream_id_only_narrows() {
    fn request_headers() -> Vec<Header> {
        [
            (":method", "GET"),
            (":path", "/"),
            (":scheme", "https"),
            (":authority", "example.com"),
        ]
        .iter()
        .map(|&(name, value)| Header::new(name, value))
        .collect()
    }

    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    // Open four local streams, draining each queued HEADERS frame.
    for _ in 0..4 {
        conn.open_stream(request_headers(), false).unwrap();
        let _ = conn.next_frame();
    }

    let first = conn
        .process_frame(Frame::GoAway(GoAwayFrame::new(5, ErrorCode::NoError)))
        .unwrap()
        .unwrap();
    if let ReceivedFrame::GoAway { last_stream_id, .. } = first {
        assert_eq!(last_stream_id, 5);
    } else {
        panic!("expected GoAway");
    }
    assert!(!conn.stream(5).unwrap().state().is_closed());
    assert_eq!(conn.stream(7).unwrap().state(), StreamState::Closed);

    // A wider boundary must not widen what the peer will process.
    let second = conn
        .process_frame(Frame::GoAway(GoAwayFrame::new(7, ErrorCode::InternalError)))
        .unwrap()
        .unwrap();
    if let ReceivedFrame::GoAway {
        last_stream_id,
        error_code,
        ..
    } = second
    {
        assert_eq!(last_stream_id, 5);
        assert_eq!(error_code, ErrorCode::InternalError);
    } else {
        panic!("expected GoAway");
    }
    assert!(!conn.stream(5).unwrap().state().is_closed());

    // A narrower boundary is honoured and closes everything above it.
    let third = conn
        .process_frame(Frame::GoAway(GoAwayFrame::new(1, ErrorCode::NoError)))
        .unwrap()
        .unwrap();
    if let ReceivedFrame::GoAway { last_stream_id, .. } = third {
        assert_eq!(last_stream_id, 1);
    } else {
        panic!("expected GoAway");
    }
    assert_eq!(conn.stream(3).unwrap().state(), StreamState::Closed);
    assert_eq!(conn.stream(5).unwrap().state(), StreamState::Closed);
}
/// Calling `goaway` twice leaves exactly one GOAWAY frame queued. The second
/// call does not enqueue another frame.
#[test]
fn test_goaway_sent_once() {
    let mut server = Connection::server(Settings::default());
    server.state = ConnectionState::Open;

    server.goaway(ErrorCode::NoError, Bytes::new());
    assert!(server.has_pending_frames());
    server.goaway(ErrorCode::InternalError, Bytes::new());

    let queued = server.next_frame().unwrap();
    assert!(matches!(queued, Frame::GoAway(_)));
    assert!(!server.has_pending_frames());
}
/// Once we have sent GOAWAY advertising stream 1, inbound streams 3 and 5 are
/// refused with RST_STREAM. This holds even after a peer RST_STREAM on stream
/// 3 advances `last_stream_id`: the advertised boundary stored in
/// `sent_goaway_last_stream_id` stays at 1.
#[test]
fn goaway_refusal_boundary_stays_frozen_after_later_bookkeeping() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .unwrap();

    conn.goaway(ErrorCode::NoError, Bytes::new());
    let advertised = match conn.next_frame().expect("expected GOAWAY frame") {
        Frame::GoAway(frame) => frame.last_stream_id,
        other => panic!("expected GOAWAY frame, got {other:?}"),
    };
    assert_eq!(advertised, 1);

    let refused = conn
        .process_frame(Frame::Headers(HeadersFrame::new(3, Bytes::new(), false, true)))
        .unwrap();
    assert!(refused.is_none());

    conn.process_frame(Frame::RstStream(RstStreamFrame::new(3, ErrorCode::Cancel)))
        .unwrap();
    assert_eq!(conn.last_stream_id, 3);
    assert_eq!(conn.sent_goaway_last_stream_id, Some(1));

    let refused_again = conn
        .process_frame(Frame::Headers(HeadersFrame::new(5, Bytes::new(), false, true)))
        .unwrap();
    assert!(
        refused_again.is_none(),
        "streams above the advertised GOAWAY boundary must stay refused"
    );

    // Drain everything still queued and keep only the RST_STREAM stream ids.
    let refused_streams: Vec<_> = std::iter::from_fn(|| conn.next_frame())
        .filter_map(|frame| match frame {
            Frame::RstStream(rst) => Some(rst.stream_id),
            _ => None,
        })
        .collect();
    assert_eq!(refused_streams, vec![3, 5]);
}
/// Debug data passed to `goaway` is carried unchanged in the queued GOAWAY
/// frame, together with the requested error code.
#[test]
fn test_goaway_with_debug_data() {
    let expected_debug = Bytes::from("server shutting down for maintenance");
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.goaway(ErrorCode::NoError, expected_debug.clone());

    if let Frame::GoAway(sent) = conn.next_frame().unwrap() {
        assert_eq!(sent.error_code, ErrorCode::NoError);
        assert_eq!(sent.debug_data, expected_debug);
    } else {
        panic!("expected GoAway");
    }
}
/// A received GOAWAY carrying an error code is surfaced with that code and its
/// `last_stream_id`, and the connection moves to `Closing`.
#[test]
fn test_goaway_received_with_error() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;

    let received = conn
        .process_frame(Frame::GoAway(GoAwayFrame::new(0, ErrorCode::InternalError)))
        .unwrap()
        .unwrap();
    let (code, last_id) = match received {
        ReceivedFrame::GoAway {
            error_code,
            last_stream_id,
            ..
        } => (error_code, last_stream_id),
        _ => panic!("expected GoAway"),
    };
    assert_eq!(code, ErrorCode::InternalError);
    assert_eq!(last_id, 0);

    assert!(conn.goaway_received());
    assert_eq!(conn.state, ConnectionState::Closing);
}
/// A graceful `goaway(NoError, ..)` moves the connection to `Closing` and
/// queues a GOAWAY frame carrying NO_ERROR.
#[test]
fn test_graceful_shutdown_flow() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.goaway(ErrorCode::NoError, Bytes::new());
    assert_eq!(conn.state, ConnectionState::Closing);

    let sent_code = match conn.next_frame().unwrap() {
        Frame::GoAway(g) => g.error_code,
        _ => panic!("expected GoAway"),
    };
    assert_eq!(sent_code, ErrorCode::NoError);
}
/// An inbound PING is answered with a PING that has the ACK flag set and
/// echoes the same opaque data.
#[test]
fn test_ping_ack_response() {
    let opaque_data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Ping(PingFrame::new(opaque_data)))
        .unwrap();

    if let Frame::Ping(reply) = conn.next_frame().unwrap() {
        assert!(reply.ack);
        assert_eq!(reply.opaque_data, opaque_data);
    } else {
        panic!("expected Ping ACK");
    }
}
/// A PING that already has the ACK flag set must not produce a reply frame.
#[test]
fn test_ping_ack_not_echoed() {
    let inbound_ack = Frame::Ping(PingFrame::ack([1, 2, 3, 4, 5, 6, 7, 8]));
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(inbound_ack).unwrap();
    let nothing_queued = !conn.has_pending_frames();
    assert!(nothing_queued);
}
/// RST_STREAM on a stream that was never opened (idle) must fail the whole
/// connection with PROTOCOL_ERROR. The error must not be scoped to a stream.
#[test]
fn test_rst_stream_on_idle_stream_is_connection_error() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let idle_reset = RstStreamFrame::new(999, ErrorCode::Cancel);
    let err = conn
        .process_frame(Frame::RstStream(idle_reset))
        .unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert_eq!(
        err.stream_id, None,
        "idle-stream RST_STREAM must be a connection error, not a stream error"
    );
}
/// A peer RST_STREAM on an open stream is surfaced as `ReceivedFrame::Reset`,
/// with that stream's id and the peer's error code.
#[test]
fn test_rst_stream_on_open_stream() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .unwrap();

    let outcome = conn
        .process_frame(Frame::RstStream(RstStreamFrame::new(1, ErrorCode::Cancel)))
        .unwrap()
        .unwrap();
    let (id, code) = match outcome {
        ReceivedFrame::Reset {
            stream_id,
            error_code,
        } => (stream_id, error_code),
        _ => panic!("expected Reset"),
    };
    assert_eq!(id, 1);
    assert_eq!(code, ErrorCode::Cancel);
}
/// RST_STREAM addressed to stream 0 is rejected as a connection-level
/// PROTOCOL_ERROR.
#[test]
fn test_rst_stream_on_stream_zero_is_connection_error() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let err = conn
        .process_frame(Frame::RstStream(RstStreamFrame::new(0, ErrorCode::Cancel)))
        .unwrap_err();
    assert_eq!(err.stream_id, None);
    assert_eq!(err.code, ErrorCode::ProtocolError);
}
/// DATA arriving on a stream after the peer reset it is rejected with
/// STREAM_CLOSED. Despite the test's name, the frame is not silently ignored:
/// `process_frame` returns an error.
#[test]
fn test_data_after_rst_ignored() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .unwrap();
    conn.process_frame(Frame::RstStream(RstStreamFrame::new(1, ErrorCode::Cancel)))
        .unwrap();
    assert_eq!(conn.stream(1).unwrap().state(), StreamState::Closed);

    let late_data = DataFrame::new(1, Bytes::from("test"), false);
    let err = conn.process_frame(Frame::Data(late_data)).unwrap_err();
    assert_eq!(err.code, ErrorCode::StreamClosed);
}
/// `reset_stream` discards DATA still queued for the stream. After the reset,
/// the RST_STREAM frame is the only frame left to send.
#[test]
fn test_reset_stream_drops_queued_outbound_data() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let request = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    let stream_id = conn.open_stream(request, false).unwrap();

    let sent_headers = match conn.next_frame().expect("expected request HEADERS") {
        Frame::Headers(h) => h,
        other => panic!("expected HEADERS frame, got {other:?}"),
    };
    assert_eq!(sent_headers.stream_id, stream_id);

    conn.send_data(stream_id, Bytes::from("queued"), true).unwrap();
    conn.reset_stream(stream_id, ErrorCode::Cancel);

    let rst = match conn.next_frame().expect("expected RST_STREAM frame") {
        Frame::RstStream(rst) => rst,
        other => panic!("expected RST_STREAM frame, got {other:?}"),
    };
    assert_eq!(rst.stream_id, stream_id);
    assert_eq!(rst.error_code, ErrorCode::Cancel);

    // The queued DATA must not follow the reset.
    assert!(conn.next_frame().is_none());
}
/// Two streams are opened with their HEADERS still queued, then a GOAWAY with
/// `last_stream_id = 1` is received. Stream 3 is marked REFUSED_STREAM and its
/// queued HEADERS are dropped. Stream 1's HEADERS are still sent.
#[test]
fn goaway_received_drops_queued_headers_for_refused_local_streams() {
    fn request_headers() -> Vec<Header> {
        [
            (":method", "GET"),
            (":path", "/"),
            (":scheme", "https"),
            (":authority", "example.com"),
        ]
        .iter()
        .map(|&(name, value)| Header::new(name, value))
        .collect()
    }

    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    // Both HEADERS frames stay queued: nothing is drained before the GOAWAY.
    let stream1 = conn.open_stream(request_headers(), false).unwrap();
    let stream3 = conn.open_stream(request_headers(), false).unwrap();
    assert_eq!((stream1, stream3), (1, 3));

    conn.process_frame(Frame::GoAway(GoAwayFrame::new(1, ErrorCode::NoError)))
        .unwrap();
    let refused_code = conn.stream(stream3).unwrap().error_code();
    assert_eq!(refused_code, Some(ErrorCode::RefusedStream));

    match conn
        .next_frame()
        .expect("stream 1 HEADERS should still be sent")
    {
        Frame::Headers(frame) => assert_eq!(frame.stream_id, stream1),
        other => panic!("expected HEADERS frame, got {other:?}"),
    }
    assert!(
        conn.next_frame().is_none(),
        "queued HEADERS for reset stream 3 must be discarded"
    );
}
/// A connection-level WINDOW_UPDATE received after the peer's GOAWAY must
/// still be processed without error.
#[test]
fn test_window_update_after_goaway() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    let goaway = Frame::GoAway(GoAwayFrame::new(0, ErrorCode::NoError));
    conn.process_frame(goaway).unwrap();

    let window_update = Frame::WindowUpdate(WindowUpdateFrame::new(0, 1024));
    // `expect` (instead of `assert!(result.is_ok())`) prints the H2Error when
    // the frame is rejected, so a failure is diagnosable.
    conn.process_frame(window_update)
        .expect("WINDOW_UPDATE after GOAWAY must be accepted");
}
/// A zero-increment WINDOW_UPDATE on an open stream is a PROTOCOL_ERROR scoped
/// to that stream (a stream error, not a connection error).
#[test]
fn zero_increment_window_update_on_stream_is_stream_error() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .unwrap();

    let zero_increment = WindowUpdateFrame::new(1, 0);
    let err = conn
        .process_frame(Frame::WindowUpdate(zero_increment))
        .unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert_eq!(
        err.stream_id,
        Some(1),
        "zero increment on a stream must be a stream error, not connection"
    );
}
/// A zero-increment WINDOW_UPDATE on stream 0 (the connection window) is a
/// connection-level PROTOCOL_ERROR.
#[test]
fn zero_increment_window_update_on_connection_is_connection_error() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let err = conn
        .process_frame(Frame::WindowUpdate(WindowUpdateFrame::new(0, 0)))
        .unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert_eq!(
        err.stream_id, None,
        "zero increment on connection must be a connection error"
    );
}
/// An inbound DATA frame with END_STREAM on a half-closed (local) stream fully
/// closes it. No stream-level WINDOW_UPDATE may then be queued for that closed
/// stream, even for a large payload.
#[test]
fn final_inbound_data_does_not_queue_stream_window_update() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let request = Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true));
    conn.process_frame(request).unwrap();

    // Finish our side first so the peer's END_STREAM closes the stream.
    let response_headers = vec![Header::new(":status", "200")];
    conn.send_headers(1, response_headers, true).unwrap();
    let _ = conn
        .next_frame()
        .expect("response HEADERS should be pending");
    assert_eq!(
        conn.stream(1).unwrap().state(),
        StreamState::HalfClosedLocal
    );

    // Just over half of the default window. Presumably this is large enough
    // to trigger a window refill on an open stream — TODO confirm against the
    // flow-control threshold.
    let payload_len = (DEFAULT_CONNECTION_WINDOW_SIZE / 2) + 2;
    // Checked conversion: an `as` cast would silently wrap a negative i32.
    let payload_len = usize::try_from(payload_len).expect("payload length must be non-negative");
    let data = Bytes::from(vec![0_u8; payload_len]);
    let inbound = Frame::Data(DataFrame::new(1, data, true));
    conn.process_frame(inbound).unwrap();
    assert_eq!(conn.stream(1).unwrap().state(), StreamState::Closed);

    // Connection-level updates (stream 0) are allowed; stream 1 updates are not.
    while let Some(frame) = conn.next_frame() {
        if let Frame::WindowUpdate(update) = frame {
            assert_ne!(
                update.stream_id, 1,
                "closed streams must not emit stream-level WINDOW_UPDATE"
            );
        }
    }
}
/// A HEADERS frame without END_HEADERS leaves the connection expecting
/// CONTINUATION. A SETTINGS frame arriving at that point is rejected with
/// PROTOCOL_ERROR.
#[test]
fn test_settings_during_continuation() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let unfinished_headers = HeadersFrame::new(1, Bytes::new(), false, false);
    conn.process_frame(Frame::Headers(unfinished_headers)).unwrap();
    assert!(conn.is_awaiting_continuation());

    let interloper = Frame::Settings(SettingsFrame::new(vec![]));
    let code = conn.process_frame(interloper).unwrap_err().code;
    assert_eq!(code, ErrorCode::ProtocolError);
}
/// While the connection is awaiting CONTINUATION, an interleaved PING is
/// rejected with PROTOCOL_ERROR.
#[test]
fn test_ping_during_continuation() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let unfinished_headers = HeadersFrame::new(1, Bytes::new(), false, false);
    conn.process_frame(Frame::Headers(unfinished_headers)).unwrap();
    assert!(conn.is_awaiting_continuation());

    let interloper = Frame::Ping(PingFrame::new([0; 8]));
    let code = conn.process_frame(interloper).unwrap_err().code;
    assert_eq!(code, ErrorCode::ProtocolError);
}
/// After the peer opens streams 1 and 3, a server GOAWAY advertises the
/// highest processed stream id, which is 3.
#[test]
fn goaway_reflects_last_processed_stream() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    for id in [1, 3] {
        conn.process_frame(Frame::Headers(HeadersFrame::new(id, Bytes::new(), false, true)))
            .unwrap();
    }

    conn.goaway(ErrorCode::NoError, Bytes::new());
    let advertised = match conn.next_frame().unwrap() {
        Frame::GoAway(g) => g.last_stream_id,
        _ => panic!("expected GoAway"),
    };
    assert_eq!(
        advertised, 3,
        "GOAWAY should report highest processed stream ID"
    );
}
/// The peer sends DATA on stream 1, then HEADERS (with END_STREAM) on stream
/// 3. A later GOAWAY advertises 3. Other frames may be queued ahead of the
/// GOAWAY, so the queue is searched for it.
#[test]
fn goaway_reflects_last_processed_data_stream() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .unwrap();
    conn.process_frame(Frame::Data(DataFrame::new(1, Bytes::from("hello"), false)))
        .unwrap();
    conn.process_frame(Frame::Headers(HeadersFrame::new(3, Bytes::new(), true, true)))
        .unwrap();
    conn.goaway(ErrorCode::NoError, Bytes::new());

    // Pull frames only until the first GOAWAY is found.
    let goaway_frame =
        std::iter::from_fn(|| conn.next_frame()).find(|f| matches!(f, Frame::GoAway(_)));
    match goaway_frame.unwrap() {
        Frame::GoAway(g) => assert_eq!(g.last_stream_id, 3),
        _ => panic!("expected GoAway"),
    }
}
/// A header block too large for the peer's frame size is split into HEADERS
/// plus CONTINUATION frames. The PING ACK queued before the request is sent
/// first. After that, only CONTINUATION frames may follow the HEADERS, until
/// one carries END_HEADERS.
#[test]
fn continuation_frames_not_interleaved_with_pending_ops() {
    let mut conn = Connection::client(Settings::client());
    conn.state = ConnectionState::Open;
    // A tiny peer frame size forces the header block across several frames.
    conn.remote_settings.max_frame_size = 50;
    conn.pending_ops
        .push_back(PendingOp::PingAck([9, 8, 7, 6, 5, 4, 3, 2]));

    let mut headers = vec![
        Header::new(":method", "GET"),
        Header::new(":path", "/some/very/long/path/that/exceeds/frame/size"),
        Header::new(":scheme", "https"),
        Header::new(":authority", "example.com"),
    ];
    for i in 0..10 {
        headers.push(Header::new(
            format!("x-custom-header-{i}"),
            format!("value-{i}"),
        ));
    }
    let _ = conn.open_stream(headers, true).unwrap();

    let frame1 = conn.next_frame().unwrap();
    assert!(
        matches!(frame1, Frame::Ping(_)),
        "first frame should be the pre-existing PingAck"
    );

    let frame2 = conn.next_frame().unwrap();
    match &frame2 {
        Frame::Headers(h) => {
            assert!(
                !h.end_headers,
                "headers too large, should have CONTINUATION"
            );
        }
        other => panic!("expected HEADERS, got {other:?}"),
    }

    // Every remaining frame up to END_HEADERS must be a CONTINUATION.
    loop {
        match conn.next_frame() {
            Some(Frame::Continuation(c)) => {
                if c.end_headers {
                    break;
                }
            }
            Some(other) => {
                panic!("expected CONTINUATION but got {other:?} — interleaving detected!")
            }
            None => panic!("ran out of frames before end_headers"),
        }
    }
}
/// DATA on a stream the peer never opened (idle) is a connection-level
/// PROTOCOL_ERROR.
#[test]
fn data_on_idle_stream_is_connection_error() {
    let idle_data = DataFrame::new(1, Bytes::from("hello"), false);
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    let err = conn.process_frame(Frame::Data(idle_data)).unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert_eq!(
        err.stream_id, None,
        "idle-stream DATA must be a connection error, not a stream error"
    );
}
/// Only stream 1 has been opened. A WINDOW_UPDATE naming stream 3, which is
/// still idle, is a connection-level PROTOCOL_ERROR.
#[test]
fn window_update_on_idle_stream_is_connection_error() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .unwrap();

    let idle_update = WindowUpdateFrame::new(3, 1024);
    let err = conn
        .process_frame(Frame::WindowUpdate(idle_update))
        .unwrap_err();
    assert_eq!(err.code, ErrorCode::ProtocolError);
    assert_eq!(
        err.stream_id, None,
        "idle-stream WINDOW_UPDATE must be a connection error, not a stream error"
    );
}
/// The first `DEFAULT_RST_STREAM_RATE_LIMIT` open-then-reset cycles are
/// accepted. The next peer RST_STREAM fails the connection with
/// ENHANCE_YOUR_CALM.
#[test]
fn rst_stream_flood_triggers_enhance_your_calm() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;

    // Client-initiated (odd) stream ids 1, 3, 5, ...
    for stream_id in (0..DEFAULT_RST_STREAM_RATE_LIMIT).map(|i| i * 2 + 1) {
        conn.process_frame(Frame::Headers(HeadersFrame::new(stream_id, Bytes::new(), false, true)))
            .unwrap();
        conn.process_frame(Frame::RstStream(RstStreamFrame::new(stream_id, ErrorCode::Cancel)))
            .unwrap();
    }

    // One more cycle pushes the peer over the limit.
    let over_limit = DEFAULT_RST_STREAM_RATE_LIMIT * 2 + 1;
    conn.process_frame(Frame::Headers(HeadersFrame::new(over_limit, Bytes::new(), false, true)))
        .unwrap();
    let err = conn
        .process_frame(Frame::RstStream(RstStreamFrame::new(over_limit, ErrorCode::Cancel)))
        .unwrap_err();
    assert_eq!(err.code, ErrorCode::EnhanceYourCalm);
    assert!(
        err.stream_id.is_none(),
        "RST_STREAM flood must be a connection error"
    );
}
/// The RST_STREAM rate-limit window is measured with the connection's injected
/// time getter. Once the test clock moves past the window, the next
/// RST_STREAM is accepted and the counter restarts at 1.
#[test]
fn rst_stream_rate_limit_window_uses_time_getter() {
    // Keep the guard from `lock_test_clock` alive for the whole test.
    let _clock_guard = lock_test_clock();
    set_test_time_offset(Duration::ZERO);
    let mut conn = Connection::server_with_time_getter(Settings::default(), test_now);
    conn.state = ConnectionState::Open;

    for stream_id in (0..DEFAULT_RST_STREAM_RATE_LIMIT).map(|i| i * 2 + 1) {
        conn.process_frame(Frame::Headers(HeadersFrame::new(stream_id, Bytes::new(), false, true)))
            .unwrap();
        conn.process_frame(Frame::RstStream(RstStreamFrame::new(stream_id, ErrorCode::Cancel)))
            .unwrap();
    }

    let window_ms = u64::try_from(DEFAULT_RST_STREAM_RATE_WINDOW_MS).expect("window fits u64");
    advance_test_time(Duration::from_millis(window_ms + 1));

    let next_id = DEFAULT_RST_STREAM_RATE_LIMIT * 2 + 1;
    conn.process_frame(Frame::Headers(HeadersFrame::new(next_id, Bytes::new(), false, true)))
        .unwrap();
    conn.process_frame(Frame::RstStream(RstStreamFrame::new(next_id, ErrorCode::Cancel)))
        .expect("rate-limit window should reset");
    assert_eq!(conn.rst_stream_count, 1);
}
/// The limit is set to `u32::MAX`. The RST_STREAM that brings the counter to
/// `u32::MAX` is allowed. One more is rejected with ENHANCE_YOUR_CALM, and the
/// counter stays at `u32::MAX` rather than wrapping.
#[test]
fn rst_stream_rate_limit_rejects_after_u32_max_without_wrapping() {
    let limit = RstStreamRateLimit {
        max_rst_streams: u32::MAX,
        rst_window_ms: DEFAULT_RST_STREAM_RATE_WINDOW_MS,
    };
    let mut conn = Connection::server(Settings::default()).rst_stream_rate_limit(limit);
    conn.state = ConnectionState::Open;
    for stream_id in [1, 3] {
        conn.process_frame(Frame::Headers(HeadersFrame::new(stream_id, Bytes::new(), false, true)))
            .unwrap();
    }

    // Jump straight to just below the cap instead of sending billions of frames.
    conn.rst_stream_count = u32::MAX - 1;
    conn.process_frame(Frame::RstStream(RstStreamFrame::new(1, ErrorCode::Cancel)))
        .expect("u32::MAXth RST_STREAM should still be allowed");
    assert_eq!(conn.rst_stream_count, u32::MAX);

    let code = conn
        .process_frame(Frame::RstStream(RstStreamFrame::new(3, ErrorCode::Cancel)))
        .unwrap_err()
        .code;
    assert_eq!(code, ErrorCode::EnhanceYourCalm);
    assert_eq!(conn.rst_stream_count, u32::MAX);
}
/// A server receiving HEADERS on even stream 2 rejects it with
/// PROTOCOL_ERROR. The rejected id must not raise the last-processed stream
/// id, so a later GOAWAY still advertises 1.
#[test]
fn headers_on_wrong_parity_stream_does_not_pollute_last_stream_id() {
    let mut conn = Connection::server(Settings::default());
    conn.state = ConnectionState::Open;
    conn.process_frame(Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true)))
        .unwrap();

    let rejected_code = conn
        .process_frame(Frame::Headers(HeadersFrame::new(2, Bytes::new(), false, true)))
        .unwrap_err()
        .code;
    assert_eq!(rejected_code, ErrorCode::ProtocolError);

    conn.goaway(ErrorCode::NoError, Bytes::new());
    let advertised = match conn.next_frame().unwrap() {
        Frame::GoAway(g) => g.last_stream_id,
        _ => panic!("expected GoAway"),
    };
    assert_eq!(
        advertised, 1,
        "last_stream_id must not be bumped by rejected HEADERS"
    );
}
/// `send_connection_window_update(0)` is rejected with FLOW_CONTROL_ERROR.
#[test]
fn connection_window_update_rejects_zero_increment() {
    let mut server = Connection::server(Settings::default());
    let outcome = server.send_connection_window_update(0);
    assert_eq!(outcome.unwrap_err().code, ErrorCode::FlowControlError);
}
/// `send_stream_window_update` with a zero increment is rejected with
/// FLOW_CONTROL_ERROR.
#[test]
fn stream_window_update_rejects_zero_increment() {
    let mut server = Connection::server(Settings::default());
    let outcome = server.send_stream_window_update(1, 0);
    assert_eq!(outcome.unwrap_err().code, ErrorCode::FlowControlError);
}
/// A non-zero connection-level window update succeeds and leaves a frame
/// pending.
#[test]
fn connection_window_update_accepts_valid_increment() {
    let mut server = Connection::server(Settings::default());
    let outcome = server.send_connection_window_update(1024);
    assert!(outcome.is_ok());
    assert!(server.has_pending_frames());
}
/// Once the peer has opened stream 1, a non-zero stream-level window update
/// for it succeeds and leaves a frame pending.
#[test]
fn stream_window_update_accepts_valid_increment() {
    let mut server = Connection::server(Settings::default());
    let open_stream_1 = Frame::Headers(HeadersFrame::new(1, Bytes::new(), false, true));
    server.process_frame(open_stream_1).unwrap();
    let outcome = server.send_stream_window_update(1, 4096);
    assert!(outcome.is_ok());
    assert!(server.has_pending_frames());
}
}