use crate::{
Body, Conn, Headers, HttpContext, Method, Status,
h2::{
H2Driver, H2Error, H2ErrorCode, H2Transport,
acceptor::{
recv::CLIENT_PREFACE,
types::{CloseOutcome, DriverState},
},
connection::H2Connection,
frame::{
FRAME_HEADER_LEN, Frame, FrameHeader, data as data_frame, headers as headers_frame,
settings,
},
settings::H2Settings,
},
headers::{
header_observer::HeaderObserver,
hpack::{FieldSection, HpackEncoder, PseudoHeaders},
},
};
use std::{
sync::Arc,
task::{Context, Poll, Wake, Waker},
};
use trillium_testing::TestTransport;
/// A [`Wake`] implementation that ignores every wake-up.
///
/// The fixture polls the driver by hand on each `tick`, so there is never a task
/// that needs to be rescheduled when the driver becomes ready.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {
        // Nothing to notify; consuming the `Arc` is the whole job.
    }

    fn wake_by_ref(self: &Arc<Self>) {
        // Overridden so the default impl does not clone the `Arc` just to call `wake`.
    }
}
/// Builds a [`Waker`] backed by `NoopWaker`, for manually polling the driver.
fn noop_waker() -> Waker {
    Arc::new(NoopWaker).into()
}
/// Test harness pairing a server-side `H2Driver` with the opposite end of its
/// transport, so a test can play the HTTP/2 client and observe the driver's output.
pub(super) struct DriverFixture {
/// Driver under test; `new_server` runs it over one end of a `TestTransport` pair.
pub(super) driver: H2Driver<TestTransport>,
/// Connection the driver was started from (via `H2Connection::run`). Tests use it to
/// submit responses/trailers and to inspect per-stream state.
pub(super) connection: Arc<H2Connection>,
/// The other end of the transport pair: tests write client bytes with `write_all`
/// and read everything the driver wrote via `snapshot`.
pub(super) peer: TestTransport,
/// How many bytes of `peer.snapshot()` have already been handed out by
/// `next_outbound_bytes`.
peer_read_cursor: usize,
/// HPACK encoder for header blocks sent by the fake client.
// NOTE(review): built with `0, 0` sizing arguments in `new_server`; presumably this
// keeps it off the dynamic table — confirm against `HpackEncoder::new`.
peer_hpack: HpackEncoder,
}
impl DriverFixture {
pub(super) fn new_server() -> Self {
let (driver_transport, peer) = TestTransport::new();
let context = Arc::new(HttpContext::new());
let connection = H2Connection::new(context);
let driver = connection.clone().run(driver_transport);
let peer_hpack = HpackEncoder::new(Arc::new(HeaderObserver::default()), 0, 0);
Self {
driver,
connection,
peer,
peer_read_cursor: 0,
peer_hpack,
}
}
pub(super) fn peer_open_stream(
&mut self,
stream_id: u32,
method: Method,
path: &str,
end_stream: bool,
) {
let pseudos = PseudoHeaders::default()
.with_method(method)
.with_path(path)
.with_scheme("http")
.with_authority("test");
let headers = Headers::new();
let field_section = FieldSection::new(pseudos, &headers);
let mut block = Vec::new();
self.peer_hpack.encode(&field_section, &mut block);
let block_len = u32::try_from(block.len()).expect("block fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + block.len()];
headers_frame::encode_prefix(stream_id, end_stream, true, None, block_len, 0, &mut frame)
.expect("encode HEADERS prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(&block);
self.peer.write_all(&frame);
}
pub(super) fn peer_trailers(&mut self, stream_id: u32, trailers: &Headers) {
let field_section = FieldSection::new(PseudoHeaders::default(), trailers);
let mut block = Vec::new();
self.peer_hpack.encode(&field_section, &mut block);
let block_len = u32::try_from(block.len()).expect("block fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + block.len()];
headers_frame::encode_prefix(stream_id, true, true, None, block_len, 0, &mut frame)
.expect("encode HEADERS prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(&block);
self.peer.write_all(&frame);
}
pub(super) fn peer_data(&mut self, stream_id: u32, payload: &[u8], end_stream: bool) {
let payload_len = u32::try_from(payload.len()).expect("data fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + payload.len()];
data_frame::encode_prefix(stream_id, end_stream, payload_len, 0, &mut frame)
.expect("encode DATA prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(payload);
self.peer.write_all(&frame);
}
pub(super) fn complete_handshake(&mut self) {
self.peer.write_all(CLIENT_PREFACE);
let _ = self.tick();
if self.driver.state != DriverState::Running {
let _ = self.tick();
}
assert_eq!(
self.driver.state,
DriverState::Running,
"driver should reach Running after preface",
);
let empty_settings = H2Settings::default();
let mut buf = vec![0u8; settings::encoded_len(&empty_settings)];
settings::encode(&empty_settings, &mut buf).expect("encode settings");
self.peer.write_all(&buf);
let _ = self.tick();
let _ = self.next_outbound_bytes();
}
pub(super) fn tick(&mut self) -> Poll<Option<Result<Conn<H2Transport>, H2Error>>> {
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
self.driver.drive(&mut cx)
}
pub(super) fn next_outbound_bytes(&mut self) -> Vec<u8> {
let all = self.peer.snapshot();
if all.len() <= self.peer_read_cursor {
return Vec::new();
}
let bytes = all[self.peer_read_cursor..].to_vec();
self.peer_read_cursor = all.len();
bytes
}
pub(super) fn next_outbound_frames(&mut self) -> Vec<Frame> {
decode_frames(&self.next_outbound_bytes())
}
}
/// Splits a contiguous run of outbound bytes into decoded HTTP/2 frames.
///
/// # Panics
///
/// Panics if `bytes` ends partway through a frame header or payload, or if any frame
/// fails to decode. Callers pass whatever the driver has written so far, so this
/// assumes the driver never leaves a partially written frame on the wire between ticks.
fn decode_frames(bytes: &[u8]) -> Vec<Frame> {
    let mut frames = Vec::new();
    let mut rest = bytes;
    while !rest.is_empty() {
        let header = FrameHeader::decode(rest).expect("incomplete frame header");
        let frame_len = FRAME_HEADER_LEN + header.length as usize;
        // Check before slicing so a truncated frame reports what went wrong instead of
        // an opaque out-of-range slice panic.
        assert!(
            frame_len <= rest.len(),
            "truncated frame: header declares {frame_len} bytes but only {} remain",
            rest.len(),
        );
        let (frame_bytes, tail) = rest.split_at(frame_len);
        let (frame, _consumed) = Frame::decode(frame_bytes).expect("frame decode");
        frames.push(frame);
        rest = tail;
    }
    frames
}
/// Counts the GOAWAY frames in `frames`.
fn count_goaways(frames: &[Frame]) -> usize {
    frames.iter().fold(0, |goaways, frame| match frame {
        Frame::Goaway { .. } => goaways + 1,
        _ => goaways,
    })
}
#[test]
fn fixture_handshake_emits_settings_and_window_update() {
    // Deliberately bypasses `complete_handshake` so the driver's opening frames are
    // still unread when the wire is inspected.
    let mut fixture = DriverFixture::new_server();
    fixture.peer.write_all(CLIENT_PREFACE);
    for _ in 0..2 {
        let _ = fixture.tick();
    }
    let frames = fixture.next_outbound_frames();
    let (settings_count, wu_count) =
        frames
            .iter()
            .fold((0usize, 0usize), |(settings, window_updates), frame| match frame {
                Frame::Settings(_) => (settings + 1, window_updates),
                Frame::WindowUpdate { .. } => (settings, window_updates + 1),
                _ => (settings, window_updates),
            });
    assert!(
        settings_count >= 1,
        "expected initial SETTINGS in handshake outbound, got frames: {frames:?}",
    );
    assert!(
        wu_count >= 1,
        "expected initial WINDOW_UPDATE in handshake outbound, got frames: {frames:?}",
    );
}
#[test]
fn peer_headers_opening_stream_yields_conn() {
    let mut fixture = DriverFixture::new_server();
    fixture.complete_handshake();
    fixture.peer_open_stream(1, Method::Get, "/", true);
    let conn = match fixture.tick() {
        Poll::Ready(Some(Ok(conn))) => conn,
        other => panic!("expected Ready(Some(Ok(conn))) yielding the new request, got {other:?}"),
    };
    assert_eq!(conn.method(), Method::Get);
    assert_eq!(conn.path(), "/");
}
#[test]
fn closing_to_drained_waits_for_in_flight_stream() {
    let mut fixture = DriverFixture::new_server();
    fixture.complete_handshake();
    // No END_STREAM on the request, so the peer's side of stream 1 stays open.
    fixture.peer_open_stream(1, Method::Post, "/", false);
    // Keep the yielded Conn alive for the rest of the test.
    let _conn_guard = match fixture.tick() {
        Poll::Ready(Some(Ok(conn))) => conn,
        other => panic!("expected Conn yielded for stream 1, got {other:?}"),
    };

    fixture.driver.begin_close(CloseOutcome::Graceful);
    let _ = fixture.tick();
    assert_eq!(
        fixture.driver.state,
        DriverState::Closing,
        "in-flight stream's open recv side should hold the driver in Closing",
    );

    // An empty DATA frame with END_STREAM finishes the peer's side of stream 1.
    fixture.peer_data(1, &[], true);
    let _ = fixture.tick();
    assert_eq!(
        fixture.driver.state,
        DriverState::Drained,
        "with the last in-flight stream's recv side closed, Closing should advance to Drained",
    );
}
#[test]
fn submit_trailers_lands_on_wire_after_body_parked() {
    let mut fixture = DriverFixture::new_server();
    fixture.complete_handshake();
    fixture.peer_open_stream(1, Method::Get, "/", true);
    let _conn = match fixture.tick() {
        Poll::Ready(Some(Ok(conn))) => conn,
        other => panic!("expected Conn yielded for stream 1, got {other:?}"),
    };

    // Response headers first; they must go out without END_STREAM.
    let status_only = PseudoHeaders::default().with_status(Status::Ok);
    let _submit = fixture
        .connection
        .submit_upgrade(1, status_only, Headers::new());
    let _ = fixture.tick();
    let headers_round = fixture.next_outbound_frames();
    let response_headers_sent = headers_round.iter().any(|frame| {
        matches!(
            frame,
            Frame::Headers {
                stream_id: 1,
                end_stream: false,
                ..
            }
        )
    });
    assert!(
        response_headers_sent,
        "response HEADERS (without END_STREAM) should be on the wire after first tick; got \
         {headers_round:?}",
    );

    // Then trailers, which should close the stream with exactly one HEADERS frame.
    let mut trailers = Headers::new();
    trailers.insert("grpc-status", "0");
    fixture
        .connection
        .submit_trailers(1, trailers)
        .expect("submit_trailers on a live stream");
    let _ = fixture.tick();
    let trailing = fixture.next_outbound_frames();
    let trailing_headers = trailing.iter().fold(0usize, |seen, frame| {
        let is_trailer = matches!(
            frame,
            Frame::Headers {
                stream_id: 1,
                end_stream: true,
                ..
            }
        );
        if is_trailer { seen + 1 } else { seen }
    });
    assert_eq!(
        trailing_headers, 1,
        "exactly one trailing HEADERS with END_STREAM should land on the wire after \
         submit_trailers; got {trailing:?}",
    );
}
#[test]
fn peer_end_stream_after_server_trailers_is_not_reset() {
    let mut fixture = DriverFixture::new_server();
    fixture.complete_handshake();
    fixture.peer_open_stream(1, Method::Post, "/", false);
    let _conn = match fixture.tick() {
        Poll::Ready(Some(Ok(conn))) => conn,
        other => panic!("expected Conn yielded for stream 1, got {other:?}"),
    };

    let status_only = PseudoHeaders::default().with_status(Status::Ok);
    let _submit = fixture
        .connection
        .submit_upgrade(1, status_only, Headers::new());
    let _ = fixture.tick();
    // Skip past the response HEADERS; only what follows matters here.
    let _ = fixture.next_outbound_frames();

    let mut trailers = Headers::new();
    trailers.insert("grpc-status", "0");
    fixture
        .connection
        .submit_trailers(1, trailers)
        .expect("submit_trailers on a live stream");
    let _ = fixture.tick();
    let trailing = fixture.next_outbound_frames();
    let server_side_closed = trailing.iter().any(|frame| {
        matches!(
            frame,
            Frame::Headers {
                stream_id: 1,
                end_stream: true,
                ..
            }
        )
    });
    assert!(
        server_side_closed,
        "server's trailing HEADERS with END_STREAM should be on the wire; got {trailing:?}",
    );

    // The peer now finishes its own side; the driver must not answer with a reset.
    fixture.peer_data(1, &[], true);
    let _ = fixture.tick();
    let after = fixture.next_outbound_frames();
    let reset_sent = after
        .iter()
        .any(|frame| matches!(frame, Frame::RstStream { stream_id: 1, .. }));
    assert!(
        !reset_sent,
        "peer's END_STREAM on a half-closed-local stream must close cleanly, not earn a \
         RST_STREAM; got {after:?}",
    );
}
#[test]
fn send_pump_emits_response_in_closing() {
    let mut fixture = DriverFixture::new_server();
    fixture.complete_handshake();
    fixture.peer_open_stream(1, Method::Get, "/", true);
    let _conn = match fixture.tick() {
        Poll::Ready(Some(Ok(conn))) => conn,
        other => panic!("expected Conn yielded for stream 1, got {other:?}"),
    };

    let response_pseudos = PseudoHeaders::default().with_status(Status::Ok);
    let body = Body::new_static(b"hi" as &[u8]);
    let _submit = fixture
        .connection
        .submit_send(1, response_pseudos, Headers::new(), Some(body));
    // Start closing before the response has been pumped; it must still go out.
    fixture.driver.begin_close(CloseOutcome::Graceful);
    let _ = fixture.tick();
    let frames = fixture.next_outbound_frames();

    let mut response_headers = 0usize;
    let mut data_frames = 0usize;
    let mut end_stream_data = false;
    for frame in &frames {
        match frame {
            Frame::Headers { stream_id: 1, .. } => response_headers += 1,
            Frame::Data {
                stream_id: 1,
                end_stream,
                ..
            } => {
                data_frames += 1;
                end_stream_data |= *end_stream;
            }
            _ => {}
        }
    }

    assert!(
        response_headers >= 1,
        "send pump should emit response HEADERS for stream 1 while Closing; got {frames:?}",
    );
    assert!(
        data_frames >= 1,
        "send pump should emit DATA for stream 1 while Closing; got {frames:?}",
    );
    assert!(
        end_stream_data,
        "send pump should terminate stream 1 with END_STREAM; got {frames:?}",
    );
}
#[test]
fn recv_pump_decodes_trailing_headers_in_closing() {
    let mut fixture = DriverFixture::new_server();
    fixture.complete_handshake();
    fixture.peer_open_stream(1, Method::Post, "/", false);
    let _conn = match fixture.tick() {
        Poll::Ready(Some(Ok(conn))) => conn,
        other => panic!("expected Conn yielded for stream 1, got {other:?}"),
    };
    // Hold on to stream 1's shared state so its trailers can be read back later.
    let stream_state = fixture
        .connection
        .streams_lock()
        .get(&1)
        .cloned()
        .expect("stream 1 registered");

    fixture.driver.begin_close(CloseOutcome::Graceful);
    let _ = fixture.tick();
    assert_eq!(fixture.driver.state, DriverState::Closing);

    let mut trailers_in = Headers::new();
    trailers_in.insert("grpc-status", "0");
    trailers_in.insert("grpc-message", "ok");
    fixture.peer_trailers(1, &trailers_in);
    let _ = fixture.tick();

    let stashed_slot = stream_state
        .recv
        .trailers
        .lock()
        .expect("recv.trailers mutex poisoned")
        .clone();
    let stashed =
        stashed_slot.expect("driver should have stashed trailers from the post-GOAWAY frame");
    for (name, expected) in [("grpc-status", "0"), ("grpc-message", "ok")] {
        assert_eq!(stashed.get_str(name), Some(expected));
    }
}
#[test]
fn closing_discards_new_stream_headers() {
    let mut fixture = DriverFixture::new_server();
    fixture.complete_handshake();
    // Stream 1 stays open (and its Conn alive) so the driver is held in Closing.
    fixture.peer_open_stream(1, Method::Post, "/", false);
    let stream_one = match fixture.tick() {
        Poll::Ready(Some(Ok(conn))) => conn,
        other => panic!("expected Conn yielded for stream 1, got {other:?}"),
    };

    fixture.driver.begin_close(CloseOutcome::Graceful);
    let _ = fixture.tick();
    assert_eq!(fixture.driver.state, DriverState::Closing);
    // Discard whatever the close produced; only the reaction to stream 3 matters.
    let _ = fixture.next_outbound_bytes();

    fixture.peer_open_stream(3, Method::Get, "/late", true);
    let polled = fixture.tick();
    let yielded_conn = matches!(polled, Poll::Ready(Some(Ok(_))));
    assert!(
        !yielded_conn,
        "post-GOAWAY HEADERS opening a new stream must not yield a Conn; got {polled:?}",
    );
    drop(stream_one);
}
#[test]
fn begin_close_is_idempotent() {
    let mut fixture = DriverFixture::new_server();
    fixture.complete_handshake();
    assert_eq!(fixture.driver.state, DriverState::Running);

    // First close: graceful, with no streams open.
    fixture.driver.begin_close(CloseOutcome::Graceful);
    let _ = fixture.tick();
    assert_eq!(fixture.driver.state, DriverState::Drained);
    let first_round = fixture.next_outbound_frames();
    assert_eq!(
        count_goaways(&first_round),
        1,
        "graceful begin_close should emit exactly one GOAWAY; got {first_round:?}",
    );
    let first_goaway_code = first_round.iter().find_map(|frame| {
        if let Frame::Goaway { error_code, .. } = frame {
            Some(*error_code)
        } else {
            None
        }
    });
    assert_eq!(
        first_goaway_code,
        Some(H2ErrorCode::NoError),
        "graceful close should queue NoError, got {first_goaway_code:?}",
    );

    // Second close with a different outcome must be a no-op on the wire.
    fixture
        .driver
        .begin_close(CloseOutcome::Protocol(H2ErrorCode::InternalError));
    let _ = fixture.tick();
    let second_round = fixture.next_outbound_frames();
    assert_eq!(
        count_goaways(&second_round),
        0,
        "second begin_close after Closing/Drained must not re-queue GOAWAY; got {second_round:?}",
    );
}