use crate::{
Conn, Headers, HttpConfig, HttpContext, Method, Status,
h2::{
H2Driver, H2Error, H2ErrorCode, H2Transport,
acceptor::{recv::CLIENT_PREFACE, types::DriverState},
connection::H2Connection,
frame::{
FRAME_HEADER_LEN, Frame, FrameHeader, continuation as continuation_frame,
data as data_frame, goaway as goaway_frame, headers as headers_frame,
rst_stream as rst_stream_frame, settings, window_update as window_update_frame,
},
role::Role,
settings::H2Settings,
},
headers::{
header_observer::HeaderObserver,
hpack::{FieldSection, HpackEncoder, PseudoHeaders},
},
};
use std::{
sync::Arc,
task::{Context, Poll, Wake, Waker},
};
use trillium_testing::TestTransport;
/// A [`Wake`] implementation that discards every wake-up.
///
/// Used to build a [`Waker`] for polling the driver by hand when no task needs rescheduling.
pub(super) struct NoopWaker;
impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {
        // A consuming wake is handled exactly like a by-reference wake: it does nothing.
        Self::wake_by_ref(&self);
    }
    fn wake_by_ref(self: &Arc<Self>) {
        // Intentionally empty; overriding this avoids the default impl's Arc clone.
    }
}
/// Returns a [`Waker`] backed by [`NoopWaker`], so waking it has no effect.
pub(super) fn noop_waker() -> Waker {
    let inner = Arc::new(NoopWaker);
    inner.into()
}
/// A [`Wake`] implementation that tallies every wake-up it receives.
///
/// Both `wake` and `wake_by_ref` add one to the inner counter; read it with
/// [`CountingWaker::count`].
pub(super) struct CountingWaker(pub(super) std::sync::atomic::AtomicUsize);
impl Wake for CountingWaker {
    fn wake(self: Arc<Self>) {
        // A consuming wake counts the same as a by-reference wake.
        Self::wake_by_ref(&self);
    }
    fn wake_by_ref(self: &Arc<Self>) {
        let Self(wakes) = self.as_ref();
        wakes.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
    }
}
impl CountingWaker {
    /// Returns how many wake-ups have been recorded so far.
    pub(super) fn count(&self) -> usize {
        let Self(wakes) = self;
        wakes.load(std::sync::atomic::Ordering::SeqCst)
    }
}
pub(super) fn counting_waker() -> (Arc<CountingWaker>, Waker) {
let counting = Arc::new(CountingWaker(std::sync::atomic::AtomicUsize::new(0)));
let waker = Waker::from(counting.clone());
(counting, waker)
}
/// Test harness pairing an [`H2Driver`] with the far end of its in-memory transport.
///
/// Tests write raw HTTP/2 frames into `peer` to play the remote endpoint. They poll the driver by
/// hand, then read back whatever the driver wrote.
pub(super) struct DriverFixture {
    /// The driver under test, polled via `DriverFixture::tick`.
    pub(super) driver: H2Driver<TestTransport>,
    /// The connection the driver was built from; shared with `driver`.
    pub(super) connection: Arc<H2Connection>,
    /// The other end of the driver's `TestTransport`. Frames written here are read by the driver.
    /// `snapshot()` exposes what the driver has written.
    pub(super) peer: TestTransport,
    /// Offset into `peer.snapshot()` of the first byte not yet returned by `next_outbound_bytes`.
    peer_read_cursor: usize,
    /// HPACK encoder used for header blocks the peer sends to the driver.
    peer_hpack: HpackEncoder,
}
impl DriverFixture {
pub(super) fn new_server() -> Self {
Self::new_server_with_config(HttpConfig::default())
}
pub(super) fn new_server_with_config(config: HttpConfig) -> Self {
let (driver_transport, peer) = TestTransport::new();
let context = Arc::new(HttpContext {
config,
..HttpContext::default()
});
let connection = H2Connection::new(context);
let driver = connection.clone().run(driver_transport);
let peer_hpack = HpackEncoder::new(Arc::new(HeaderObserver::default()), 0, 0);
Self {
driver,
connection,
peer,
peer_read_cursor: 0,
peer_hpack,
}
}
pub(super) fn new_client() -> Self {
let (driver_transport, peer) = TestTransport::new();
let connection = H2Connection::new(Arc::new(HttpContext::new()));
let driver = H2Driver::new(connection.clone(), driver_transport, Role::Client);
let peer_hpack = HpackEncoder::new(Arc::new(HeaderObserver::default()), 0, 0);
Self {
driver,
connection,
peer,
peer_read_cursor: 0,
peer_hpack,
}
}
pub(super) fn peer_open_stream(
&mut self,
stream_id: u32,
method: Method,
path: &str,
end_stream: bool,
) {
let pseudos = PseudoHeaders::default()
.with_method(method)
.with_path(path)
.with_scheme("http")
.with_authority("test");
let headers = Headers::new();
let field_section = FieldSection::new(pseudos, &headers);
let mut block = Vec::new();
self.peer_hpack.encode(&field_section, &mut block);
let block_len = u32::try_from(block.len()).expect("block fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + block.len()];
headers_frame::encode_prefix(stream_id, end_stream, true, None, block_len, 0, &mut frame)
.expect("encode HEADERS prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(&block);
self.peer.write_all(&frame);
}
pub(super) fn peer_open_stream_no_end_headers(
&mut self,
stream_id: u32,
method: Method,
path: &str,
end_stream: bool,
) {
let pseudos = PseudoHeaders::default()
.with_method(method)
.with_path(path)
.with_scheme("http")
.with_authority("test");
let headers = Headers::new();
let field_section = FieldSection::new(pseudos, &headers);
let mut block = Vec::new();
self.peer_hpack.encode(&field_section, &mut block);
let block_len = u32::try_from(block.len()).expect("block fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + block.len()];
headers_frame::encode_prefix(stream_id, end_stream, false, None, block_len, 0, &mut frame)
.expect("encode HEADERS prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(&block);
self.peer.write_all(&frame);
}
pub(super) fn peer_open_stream_split(
&mut self,
stream_id: u32,
method: Method,
path: &str,
end_stream: bool,
split_at: usize,
) {
let pseudos = PseudoHeaders::default()
.with_method(method)
.with_path(path)
.with_scheme("http")
.with_authority("test");
let headers = Headers::new();
let field_section = FieldSection::new(pseudos, &headers);
let mut block = Vec::new();
self.peer_hpack.encode(&field_section, &mut block);
let split_at = split_at.min(block.len());
let head = &block[..split_at];
let head_len = u32::try_from(head.len()).expect("fragment fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + head.len()];
headers_frame::encode_prefix(stream_id, end_stream, false, None, head_len, 0, &mut frame)
.expect("encode HEADERS prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(head);
self.peer.write_all(&frame);
self.peer_continuation(stream_id, &block[split_at..], true);
}
pub(super) fn peer_continuation(&mut self, stream_id: u32, fragment: &[u8], end_headers: bool) {
let len = u32::try_from(fragment.len()).expect("fragment fits u32");
let mut frame = vec![0u8; continuation_frame::ENCODED_PREFIX_LEN + fragment.len()];
continuation_frame::encode_prefix(stream_id, end_headers, len, &mut frame)
.expect("encode CONTINUATION prefix");
frame[continuation_frame::ENCODED_PREFIX_LEN..].copy_from_slice(fragment);
self.peer.write_all(&frame);
}
pub(super) fn peer_response_headers(
&mut self,
stream_id: u32,
status: Status,
end_stream: bool,
) {
let pseudos = PseudoHeaders::default().with_status(status);
let headers = Headers::new();
let field_section = FieldSection::new(pseudos, &headers);
let mut block = Vec::new();
self.peer_hpack.encode(&field_section, &mut block);
let block_len = u32::try_from(block.len()).expect("block fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + block.len()];
headers_frame::encode_prefix(stream_id, end_stream, true, None, block_len, 0, &mut frame)
.expect("encode HEADERS prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(&block);
self.peer.write_all(&frame);
}
pub(super) fn peer_headers(
&mut self,
stream_id: u32,
pseudos: PseudoHeaders<'static>,
fields: &Headers,
end_stream: bool,
) {
let field_section = FieldSection::new(pseudos, fields);
let mut block = Vec::new();
self.peer_hpack.encode(&field_section, &mut block);
let block_len = u32::try_from(block.len()).expect("block fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + block.len()];
headers_frame::encode_prefix(stream_id, end_stream, true, None, block_len, 0, &mut frame)
.expect("encode HEADERS prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(&block);
self.peer.write_all(&frame);
}
pub(super) fn peer_trailers(&mut self, stream_id: u32, trailers: &Headers) {
let field_section = FieldSection::new(PseudoHeaders::default(), trailers);
let mut block = Vec::new();
self.peer_hpack.encode(&field_section, &mut block);
let block_len = u32::try_from(block.len()).expect("block fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + block.len()];
headers_frame::encode_prefix(stream_id, true, true, None, block_len, 0, &mut frame)
.expect("encode HEADERS prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(&block);
self.peer.write_all(&frame);
}
pub(super) fn peer_data(&mut self, stream_id: u32, payload: &[u8], end_stream: bool) {
let payload_len = u32::try_from(payload.len()).expect("data fits u32");
let mut frame = vec![0u8; FRAME_HEADER_LEN + payload.len()];
data_frame::encode_prefix(stream_id, end_stream, payload_len, 0, &mut frame)
.expect("encode DATA prefix");
frame[FRAME_HEADER_LEN..].copy_from_slice(payload);
self.peer.write_all(&frame);
}
pub(super) fn peer_rst_stream(&mut self, stream_id: u32, code: H2ErrorCode) {
let mut frame = vec![0u8; rst_stream_frame::ENCODED_LEN];
rst_stream_frame::encode(stream_id, code, &mut frame).expect("encode RST_STREAM");
self.peer.write_all(&frame);
}
pub(super) fn peer_goaway(&mut self, last_stream_id: u32, code: H2ErrorCode) {
let mut frame = vec![0u8; goaway_frame::encoded_len(0)];
goaway_frame::encode(last_stream_id, code, &[], &mut frame).expect("encode GOAWAY");
self.peer.write_all(&frame);
}
pub(super) fn peer_window_update(&mut self, stream_id: u32, increment: u32) {
let mut frame = vec![0u8; window_update_frame::ENCODED_LEN];
window_update_frame::encode(stream_id, increment, &mut frame)
.expect("encode WINDOW_UPDATE");
self.peer.write_all(&frame);
}
pub(super) fn complete_handshake(&mut self) {
self.complete_handshake_with_peer_settings(H2Settings::default());
}
pub(super) fn complete_handshake_with_peer_settings(&mut self, settings: H2Settings) {
self.peer.write_all(CLIENT_PREFACE);
let _ = self.tick();
if self.driver.state != DriverState::Running {
let _ = self.tick();
}
assert_eq!(
self.driver.state,
DriverState::Running,
"driver should reach Running after preface",
);
let mut buf = vec![0u8; settings::encoded_len(&settings)];
settings::encode(&settings, &mut buf).expect("encode settings");
self.peer.write_all(&buf);
let _ = self.tick();
let _ = self.next_outbound_bytes();
}
pub(super) fn complete_handshake_client(&mut self) {
let _ = self.tick();
if self.driver.state != DriverState::Running {
let _ = self.tick();
}
assert_eq!(
self.driver.state,
DriverState::Running,
"client should reach Running after writing its preface + SETTINGS",
);
let settings = H2Settings::default();
let mut buf = vec![0u8; settings::encoded_len(&settings)];
settings::encode(&settings, &mut buf).expect("encode settings");
self.peer.write_all(&buf);
let _ = self.tick();
let _ = self.next_outbound_bytes();
}
pub(super) fn tick(&mut self) -> Poll<Option<Result<Conn<H2Transport>, H2Error>>> {
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
self.driver.drive(&mut cx)
}
pub(super) fn next_outbound_bytes(&mut self) -> Vec<u8> {
let all = self.peer.snapshot();
if all.len() <= self.peer_read_cursor {
return Vec::new();
}
let bytes = all[self.peer_read_cursor..].to_vec();
self.peer_read_cursor = all.len();
bytes
}
pub(super) fn next_outbound_frames(&mut self) -> Vec<Frame> {
decode_frames(&self.next_outbound_bytes())
}
}
/// Splits `bytes` into consecutive HTTP/2 frames and decodes each one.
///
/// Each frame's extent is taken from its 9-byte header's length field
/// (`FRAME_HEADER_LEN + length`).
///
/// # Panics
///
/// Panics in any of these cases, each of which means the driver wrote malformed output:
/// - `bytes` ends partway through a frame header.
/// - A header declares more payload than remains in `bytes`. The message names the offset, the
///   declared length and the remaining byte count.
/// - A frame fails to decode.
pub(super) fn decode_frames(bytes: &[u8]) -> Vec<Frame> {
    let mut frames = Vec::new();
    let mut offset = 0;
    while offset < bytes.len() {
        let remaining = &bytes[offset..];
        let header = FrameHeader::decode(remaining).expect("incomplete frame header");
        let frame_len = FRAME_HEADER_LEN + header.length as usize;
        // Checked slice: a bare index would panic with an opaque range error on a short frame.
        let Some(frame_bytes) = remaining.get(..frame_len) else {
            panic!(
                "truncated frame at offset {offset}: header declares {frame_len} bytes \
                 but only {} remain",
                remaining.len(),
            );
        };
        let (frame, _consumed) = Frame::decode(frame_bytes).expect("frame decode");
        frames.push(frame);
        offset += frame_len;
    }
    frames
}
/// Returns how many of `frames` are GOAWAY frames.
pub(super) fn count_goaways(frames: &[Frame]) -> usize {
    let mut goaways = 0;
    for frame in frames {
        if let Frame::Goaway { .. } = frame {
            goaways += 1;
        }
    }
    goaways
}
/// After the client preface, a server driver's outbound bytes include at least one SETTINGS
/// frame and at least one WINDOW_UPDATE frame.
#[test]
fn fixture_handshake_emits_settings_and_window_update() {
    let mut fx = DriverFixture::new_server();
    fx.peer.write_all(CLIENT_PREFACE);
    for _ in 0..2 {
        let _ = fx.tick();
    }
    let frames = fx.next_outbound_frames();

    // Tally both frame kinds in a single pass.
    let (settings_count, wu_count) =
        frames
            .iter()
            .fold((0usize, 0usize), |(settings, updates), frame| match frame {
                Frame::Settings(_) => (settings + 1, updates),
                Frame::WindowUpdate { .. } => (settings, updates + 1),
                _ => (settings, updates),
            });

    assert!(
        settings_count >= 1,
        "expected initial SETTINGS in handshake outbound, got frames: {frames:?}",
    );
    assert!(
        wu_count >= 1,
        "expected initial WINDOW_UPDATE in handshake outbound, got frames: {frames:?}",
    );
}
/// A peer HEADERS frame that opens stream 1 makes the next poll yield a `Conn` for that request.
#[test]
fn peer_headers_opening_stream_yields_conn() {
    let mut fx = DriverFixture::new_server();
    fx.complete_handshake();
    fx.peer_open_stream(1, Method::Get, "/", true);

    let conn = match fx.tick() {
        Poll::Ready(Some(Ok(conn))) => conn,
        other => panic!("expected Ready(Some(Ok(conn))) yielding the new request, got {other:?}"),
    };
    assert_eq!(conn.method(), Method::Get);
    assert_eq!(conn.path(), "/");
}