use super::fixture::*;
use crate::{
Body, Buffer, Headers, KnownHeaderName, Method, ProtocolSession, ReceivedBody,
ReceivedBodyState, Status,
h2::{H2ErrorCode, H2Transport, SubmitSend, acceptor::types::DriverState, frame::Frame},
headers::hpack::PseudoHeaders,
};
use futures_lite::io::AsyncRead;
use std::{
future::Future,
net::Shutdown,
pin::Pin,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
task::{Context, Poll, Wake, Waker},
};
/// Test waker that counts how many times it has been woken.
///
/// The wrapped counter is incremented by one on every `wake` and
/// `wake_by_ref` call, so a test can compare the count before and after an
/// event to see whether that event woke the task.
struct CountingWaker(AtomicUsize);

impl Wake for CountingWaker {
    /// Consuming wake: records a single wake-up via the by-reference path.
    fn wake(self: Arc<Self>) {
        Self::wake_by_ref(&self);
    }

    /// Records a single wake-up without consuming the waker.
    fn wake_by_ref(self: &Arc<Self>) {
        let CountingWaker(hits) = &**self;
        hits.fetch_add(1, Ordering::SeqCst);
    }
}
/// Opens a body-less `GET /` stream on the fixture's client connection.
///
/// After opening the stream the fixture is ticked once and the outbound
/// frames are inspected: the helper panics unless a HEADERS frame carrying
/// END_STREAM was emitted for the new stream id.
///
/// Returns the stream id together with the `SubmitSend` and `H2Transport`
/// handles produced by `open_stream`.
///
/// # Panics
///
/// Panics if `open_stream` fails or if the expected HEADERS frame is missing.
fn open_get(fx: &mut DriverFixture) -> (u32, SubmitSend, H2Transport) {
    let request = PseudoHeaders::default()
        .with_method(Method::Get)
        .with_path("/")
        .with_scheme("http")
        .with_authority("test");

    // `None` body: the request is expected to end the stream with its HEADERS.
    let opened = fx.connection.open_stream(request, Headers::new(), None);
    let (id, submit, transport) = opened.expect("open_stream on a running client connection");

    let _ = fx.tick();
    let frames = fx.next_outbound_frames();
    let saw_request_headers = frames.iter().any(|frame| {
        matches!(
            frame,
            Frame::Headers {
                stream_id,
                end_stream: true,
                ..
            } if *stream_id == id
        )
    });
    assert!(
        saw_request_headers,
        "body-less request should frame HEADERS(END_STREAM) on stream {id}; got {frames:?}",
    );

    (id, submit, transport)
}
/// Happy path: a GET is opened, the peer answers with a 200 whose HEADERS
/// also end the stream, and `response_headers` resolves `Ok` on first poll.
#[test]
fn client_request_response_round_trip() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();

    let (id, _submit, _transport) = open_get(&mut fx);
    assert_eq!(id, 1, "first client-allocated stream id is 1");

    // Peer responds; one tick lets the driver process the inbound frame.
    fx.peer_response_headers(id, Status::Ok, true);
    let _ = fx.tick();

    let noop = noop_waker();
    let mut cx = Context::from_waker(&noop);
    let mut waiter = fx.connection.response_headers(id);
    let other = Pin::new(&mut waiter).poll(&mut cx);
    assert!(
        matches!(other, Poll::Ready(Ok(_))),
        "response_headers should resolve Ok after the peer's HEADERS; got {other:?}"
    );
}
/// A peer RST_STREAM(CANCEL) on an in-flight request must drop the stream
/// from the connection's stream map and make `response_headers` resolve with
/// an error.
#[test]
fn server_rst_on_in_flight_client_stream_surfaces_to_response_waiter() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    fx.peer_rst_stream(id, H2ErrorCode::Cancel);
    let _ = fx.tick();

    let still_tracked = fx.connection.streams_lock().contains_key(&id);
    assert!(
        !still_tracked,
        "server RST should remove the client stream from the map",
    );

    let noop = noop_waker();
    let mut cx = Context::from_waker(&noop);
    let mut waiter = fx.connection.response_headers(id);
    let other = Pin::new(&mut waiter).poll(&mut cx);
    assert!(
        matches!(other, Poll::Ready(Err(_))),
        "response_headers on a server-reset stream should error, got {other:?}"
    );
}
/// A response waiter that parked before any response arrived must resolve
/// with an error once the peer sends GOAWAY and shuts the transport down.
#[test]
fn server_goaway_resolves_pending_response_waiter() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    // The waiter is created from a cloned connection handle — presumably so it
    // does not keep `fx` borrowed while the fixture is driven below.
    let conn = fx.connection.clone();
    let noop = noop_waker();
    let mut cx = Context::from_waker(&noop);
    let mut waiter = conn.response_headers(id);
    assert!(
        Pin::new(&mut waiter).poll(&mut cx).is_pending(),
        "no response yet — the waiter should park",
    );

    // GOAWAY followed by a full shutdown of the peer side, then a few ticks.
    fx.peer_goaway(0, H2ErrorCode::NoError);
    fx.peer.shutdown(Shutdown::Both);
    (0..4).for_each(|_| {
        let _ = fx.tick();
    });

    let other = Pin::new(&mut waiter).poll(&mut cx);
    assert!(
        matches!(other, Poll::Ready(Err(_))),
        "response_headers must resolve with an error once the connection dies after GOAWAY \
         (per its documented ConnectionAborted contract); got {other:?}",
    );
}
/// GOAWAY and RST_STREAM are queued by the peer back to back; after a few
/// ticks the client must have removed the reset stream and reached
/// `DriverState::Drained`.
#[test]
fn client_drains_buffered_rst_after_mirroring_peer_goaway() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    // Both frames are queued before the driver gets to run.
    fx.peer_goaway(1, H2ErrorCode::NoError);
    fx.peer_rst_stream(id, H2ErrorCode::Cancel);
    (0..4).for_each(|_| {
        let _ = fx.tick();
    });

    let still_tracked = fx.connection.streams_lock().contains_key(&id);
    assert!(
        !still_tracked,
        "client must consume the buffered RST_STREAM and remove the stream; got it still present",
    );
    assert_eq!(
        fx.driver.state,
        DriverState::Drained,
        "after mirroring GOAWAY and consuming the RST, the client's drain gate should clear",
    );
}
/// Lost-wake regression: once the driver has returned `Pending` in
/// `DriverState::Closing` with a stream still open, a later peer RST_STREAM
/// must wake the waker the driver was last polled with.
#[test]
fn client_parked_in_closing_is_rewoken_by_late_rst() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
    let waker = Waker::from(Arc::clone(&counter));
    let mut cx = Context::from_waker(&waker);

    fx.peer_goaway(1, H2ErrorCode::NoError);

    // Drive by hand until the driver parks; bail out if it never settles.
    let mut polls = 0;
    while let Poll::Ready(event) = fx.driver.drive(&mut cx) {
        assert!(
            event.is_some(),
            "driver finished before the in-flight stream drained"
        );
        polls += 1;
        assert!(polls < 100, "driver never settled to Pending");
    }

    assert_eq!(
        fx.driver.state,
        DriverState::Closing,
        "client should be Closing after mirroring the peer GOAWAY",
    );
    let still_tracked = fx.connection.streams_lock().contains_key(&id);
    assert!(
        still_tracked,
        "the in-flight recv-open stream should still be holding the drain gate",
    );

    // Sample the wake count on either side of the late RST; no tick in between.
    let wake_count = || counter.0.load(Ordering::SeqCst);
    let wakes_before = wake_count();
    fx.peer_rst_stream(id, H2ErrorCode::Cancel);
    let wakes_after = wake_count();
    assert!(
        wakes_after > wakes_before,
        "an RST arriving after the driver parked in Closing must re-wake the driver task (was \
         {wakes_before}, now {wakes_after}); no wake means the lost-wake deadlock",
    );
}
/// A POST with a 70 000-byte body is opened, the peer sends response headers
/// and then trailers, and finally RST_STREAM(NO_ERROR). The trailers received
/// before the reset must still be retrievable through `take_trailers`.
#[test]
fn server_rst_after_trailers_preserves_trailers_while_send_half_open() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();

    let request = PseudoHeaders::default()
        .with_method(Method::Post)
        .with_path("/")
        .with_scheme("http")
        .with_authority("test");
    let payload = Body::new_static(vec![0u8; 70_000]);
    let opened = fx
        .connection
        .open_stream(request, Headers::new(), Some(payload));
    let (id, _submit, _transport) = opened.expect("open_stream on a running client connection");
    let _ = fx.tick();
    // Discard whatever request frames were produced; they are not under test.
    let _ = fx.next_outbound_frames();

    // Response headers without END_STREAM, then trailers.
    fx.peer_response_headers(id, Status::Ok, false);
    let _ = fx.tick();
    let mut trailers = Headers::new();
    trailers.insert("grpc-status", "0");
    fx.peer_trailers(id, &trailers);
    let _ = fx.tick();

    let still_tracked = fx.connection.streams_lock().contains_key(&id);
    assert!(
        still_tracked,
        "stream should still be tracked after receiving trailers (send half still open)",
    );

    fx.peer_rst_stream(id, H2ErrorCode::NoError);
    let _ = fx.tick();

    let trailers_intact = match fx.connection.take_trailers(id) {
        Some(t) => t.get_str("grpc-status") == Some("0"),
        None => false,
    };
    assert!(
        trailers_intact,
        "trailers received before RST_STREAM(NoError) must be preserved, not discarded",
    );
}
/// After the peer's response HEADERS ended the stream, a further DATA frame
/// on that stream must be answered with RST_STREAM(STREAM_CLOSED).
#[test]
fn server_data_after_its_own_end_stream_is_reset() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    fx.peer_response_headers(id, Status::Ok, true);
    let _ = fx.tick();
    // Clear frames emitted so far so only the reaction to DATA is inspected.
    let _ = fx.next_outbound_frames();

    fx.peer_data(id, b"extra", false);
    let _ = fx.tick();

    let frames = fx.next_outbound_frames();
    let saw_stream_closed_rst = frames.iter().any(|frame| {
        matches!(
            frame,
            Frame::RstStream {
                stream_id,
                error_code: H2ErrorCode::StreamClosed,
            } if *stream_id == id
        )
    });
    assert!(
        saw_stream_closed_rst,
        "client must RST_STREAM(STREAM_CLOSED) on DATA after the server's END_STREAM; got \
         {frames:?}",
    );
}
/// The peer sends a `100 Continue` followed by a final `200`; the response
/// waiter must yield the 200, not the interim response.
#[test]
fn client_discards_interim_response_and_surfaces_final() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    // Interim response first (stream stays open), then the final one.
    fx.peer_response_headers(id, Status::Continue, false);
    let _ = fx.tick();
    fx.peer_response_headers(id, Status::Ok, true);
    let _ = fx.tick();

    let noop = noop_waker();
    let mut cx = Context::from_waker(&noop);
    let mut waiter = fx.connection.response_headers(id);
    let fields = match Pin::new(&mut waiter).poll(&mut cx) {
        Poll::Ready(Ok(fields)) => fields,
        other => panic!("expected the final response headers to surface, got {other:?}"),
    };
    assert_eq!(
        fields.pseudo_headers().status(),
        Some(Status::Ok),
        "the surfaced response must be the final 200, not the discarded interim 1xx",
    );
}
/// The response declares `content-length: 1` but delivers four bytes of DATA;
/// the client must answer with RST_STREAM(PROTOCOL_ERROR) and stop tracking
/// the stream.
#[test]
fn server_response_body_exceeding_content_length_is_reset() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    let response = PseudoHeaders::default().with_status(Status::Ok);
    let mut response_fields = Headers::new();
    response_fields.insert(KnownHeaderName::ContentLength, "1");
    fx.peer_headers(id, response, &response_fields, false);
    let _ = fx.tick();
    // Clear frames emitted so far so only the reaction to DATA is inspected.
    let _ = fx.next_outbound_frames();

    // Four bytes against a declared length of one.
    fx.peer_data(id, b"test", true);
    let _ = fx.tick();

    let frames = fx.next_outbound_frames();
    let saw_protocol_error_rst = frames.iter().any(|frame| {
        matches!(
            frame,
            Frame::RstStream { stream_id, error_code: H2ErrorCode::ProtocolError }
                if *stream_id == id
        )
    });
    assert!(
        saw_protocol_error_rst,
        "a response body past content-length must earn RST_STREAM(PROTOCOL_ERROR); got {frames:?}",
    );
    assert!(!fx.connection.streams_lock().contains_key(&id));
}
/// Control case for the content-length check: a four-byte body against
/// `content-length: 4` must not trigger RST_STREAM(PROTOCOL_ERROR).
#[test]
fn server_response_body_matching_content_length_is_accepted() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    let response = PseudoHeaders::default().with_status(Status::Ok);
    let mut response_fields = Headers::new();
    response_fields.insert(KnownHeaderName::ContentLength, "4");
    fx.peer_headers(id, response, &response_fields, false);
    let _ = fx.tick();
    // Clear frames emitted so far so only the reaction to DATA is inspected.
    let _ = fx.next_outbound_frames();

    fx.peer_data(id, b"test", true);
    let _ = fx.tick();

    let frames = fx.next_outbound_frames();
    let no_protocol_error_rst = frames.iter().all(|frame| {
        !matches!(
            frame,
            Frame::RstStream { stream_id, error_code: H2ErrorCode::ProtocolError }
                if *stream_id == id
        )
    });
    assert!(
        no_protocol_error_rst,
        "a response body matching content-length must not be reset; got {frames:?}",
    );
}
/// An interim `100 Continue` that also carries END_STREAM leaves no final
/// response to deliver; the response waiter must resolve with an error.
#[test]
fn client_interim_response_with_end_stream_aborts_waiter() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, _transport) = open_get(&mut fx);

    fx.peer_response_headers(id, Status::Continue, true);
    let _ = fx.tick();

    let noop = noop_waker();
    let mut cx = Context::from_waker(&noop);
    let mut waiter = fx.connection.response_headers(id);
    let other = Pin::new(&mut waiter).poll(&mut cx);
    assert!(
        matches!(other, Poll::Ready(Err(_))),
        "an interim response with END_STREAM should abort the response waiter, not hang or \
         surface a response; got {other:?}",
    );
}
/// With `content-length: 0` and no END_STREAM yet, reading the body must
/// return `Pending` rather than EOF. Once the peer's trailers arrive, the
/// read must return `Ok(0)` and the trailers must be available in the slot
/// registered with `with_trailers`.
#[test]
fn content_length_response_defers_eof_until_end_stream_so_trailers_survive() {
    let mut fx = DriverFixture::new_client();
    fx.complete_handshake_client();
    let (id, _submit, mut transport) = open_get(&mut fx);

    // Response headers announce an empty body but leave the stream open.
    let response = PseudoHeaders::default().with_status(Status::Ok);
    let mut response_fields = Headers::new();
    response_fields.insert(KnownHeaderName::ContentLength, "0");
    fx.peer_headers(id, response, &response_fields, false);
    let _ = fx.tick();

    let mut buffer = Buffer::with_capacity(64);
    let mut state = ReceivedBodyState::new_h2();
    let mut received_trailers: Option<Headers> = None;
    let session = ProtocolSession::Http2 {
        connection: fx.connection.clone(),
        stream_id: id,
    };
    let mut body: ReceivedBody<'_, H2Transport> = ReceivedBody::new(
        Some(0),
        &mut buffer,
        &mut transport,
        &mut state,
        None,
        encoding_rs::UTF_8,
    )
    .with_protocol_session(session)
    .with_trailers(&mut received_trailers);

    let wake_counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
    let waker = Waker::from(wake_counter);
    let mut cx = Context::from_waker(&waker);
    let mut scratch = [0u8; 16];

    assert!(
        matches!(
            Pin::new(&mut body).poll_read(&mut cx, &mut scratch),
            Poll::Pending
        ),
        "body must not declare EOF on content-length alone before END_STREAM (would lose trailers)",
    );

    let mut trailers = Headers::new();
    trailers.insert("grpc-status", "0");
    fx.peer_trailers(id, &trailers);
    let _ = fx.tick();

    let read = Pin::new(&mut body).poll_read(&mut cx, &mut scratch);
    assert!(
        matches!(read, Poll::Ready(Ok(0))),
        "body should reach clean EOF once END_STREAM has arrived",
    );

    // `body` is not touched past this point, so the trailer slot can be read.
    let grpc_status = received_trailers
        .as_ref()
        .and_then(|t| t.get_str("grpc-status"));
    assert_eq!(
        grpc_status,
        Some("0"),
        "trailers delivered with END_STREAM must be surfaced through the body, not dropped",
    );
}