mod closed_streams;
mod constants;
mod handler_signals;
mod outbound;
mod recv;
mod send;
mod types;
use super::{
H2Error, H2ErrorCode, connection::H2Connection, frame::FRAME_HEADER_LEN, role::Role,
transport::H2Transport,
};
use crate::{
Conn,
headers::hpack::{HpackDecoder, HpackEncoder},
};
use closed_streams::{ClosedReason, ClosedStreams};
use constants::{
INITIAL_CONNECTION_RECV_WINDOW, MAX_BUFFER_SIZE, MAX_DATA_CHUNK_SIZE, MAX_FLOW_CONTROL_WINDOW,
};
use futures_lite::io::{AsyncRead, AsyncWrite};
use hashbrown::HashMap;
use recv::PendingHeaders;
use std::{
future::Future,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll, ready},
};
use swansong::ShuttingDown;
use types::{
AcceptorConfig, Action, CloseOutcome, DriverState, Next, ReadPhase, StreamEntry, frame_slice,
};
/// Per-connection HTTP/2 driver.
///
/// Owns the transport plus the connection-local protocol state: read/write buffers, the
/// HPACK codec, the stream table and the connection-level flow-control windows. State
/// shared with other parts of the connection lives behind `connection`. The driver makes
/// progress only when polled (see `next`).
#[derive(Debug)]
pub struct H2Driver<T> {
    /// Shared connection handle; exposed via `connection()`.
    connection: Arc<H2Connection>,
    /// Underlying byte stream that frames are read from and written to.
    transport: T,
    /// Which side of the connection this driver plays; selects preface handling.
    role: Role,
    /// Current position in the driver's lifecycle state machine.
    state: DriverState,
    /// Shutdown signal from the connection's swansong; while running, its readiness
    /// triggers a graceful close.
    shutting_down: ShuttingDown,
    /// Inbound byte buffer; starts at frame-header size and grows on demand.
    read_buf: Vec<u8>,
    /// Number of leading bytes of `read_buf` that currently hold received data.
    read_filled: usize,
    /// Which part of a frame is expected next (starts at `NeedHeader`).
    read_phase: ReadPhase,
    /// Outbound bytes queued for the transport.
    write_buf: Vec<u8>,
    /// Presumably the offset into `write_buf` already handed to the transport — TODO confirm.
    write_cursor: usize,
    /// Presumably set while written bytes still await a transport flush — TODO confirm.
    write_flush_pending: bool,
    /// HPACK decoder for inbound header blocks.
    hpack: HpackDecoder,
    /// HPACK encoder for outbound header blocks.
    hpack_encoder: HpackEncoder,
    /// Per-stream state, presumably keyed by HTTP/2 stream id.
    streams: HashMap<u32, StreamEntry>,
    /// Peer stream id reported as the last-stream-id of any GOAWAY this driver sends.
    last_peer_stream_id: u32,
    /// Presumably a header block still awaiting CONTINUATION frames — TODO confirm.
    pending_headers: Option<PendingHeaders>,
    /// Why the connection is closing; the first recorded outcome wins.
    close_outcome: Option<CloseOutcome>,
    /// Set once a terminal result has been produced; later polls yield `Ready(None)`.
    finished: bool,
    /// Reusable buffer of `MAX_DATA_CHUNK_SIZE` bytes (presumably for DATA payloads).
    body_scratch: Vec<u8>,
    /// Connection-level flow-control window for data this side sends.
    connection_send_window: i64,
    /// Connection-level flow-control window for data this side receives.
    connection_recv_window: i64,
    /// Record of closed streams and why they closed; queried via `closed_reason`.
    closed_streams: ClosedStreams,
    /// Tunables derived from the HTTP config when the driver was built.
    pub(super) config: AcceptorConfig,
}
impl<T> H2Driver<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send,
{
pub(super) fn new(connection: Arc<H2Connection>, transport: T, role: Role) -> Self {
let shutting_down = connection.swansong().shutting_down();
let context = connection.context();
let config = AcceptorConfig::from_http_config(context.config());
let hpack_encoder = HpackEncoder::new(
context.observer.clone(),
context.config.dynamic_table_capacity(),
context.config.recent_pairs_size(),
);
Self {
connection,
transport,
role,
state: DriverState::AwaitingPreface,
shutting_down,
read_buf: vec![0u8; FRAME_HEADER_LEN],
read_filled: 0,
read_phase: ReadPhase::NeedHeader,
write_buf: Vec::new(),
write_cursor: 0,
write_flush_pending: false,
hpack: HpackDecoder::new(config.hpack_table_capacity()),
hpack_encoder,
streams: HashMap::new(),
last_peer_stream_id: 0,
pending_headers: None,
close_outcome: None,
finished: false,
body_scratch: vec![0u8; MAX_DATA_CHUNK_SIZE as usize],
connection_send_window: INITIAL_CONNECTION_RECV_WINDOW,
connection_recv_window: INITIAL_CONNECTION_RECV_WINDOW,
closed_streams: ClosedStreams::default(),
config,
}
}
pub fn connection(&self) -> &Arc<H2Connection> {
&self.connection
}
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> Next<'_, T> {
Next { driver: self }
}
pub(super) fn drive(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Conn<H2Transport>, H2Error>>> {
if self.finished {
return Poll::Ready(None);
}
for loop_number in 0..self.config.copy_loops_per_yield() {
log::trace!("h2 drive loop number: {loop_number}");
self.service_handler_signals();
self.advance_outbound_sends(cx);
match self.poll_flush_outbound(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => {
if self.close_outcome.is_none() {
self.close_outcome = Some(CloseOutcome::Io(e));
}
return Poll::Ready(self.finish_with_current_outcome());
}
Poll::Pending => return Poll::Pending,
}
if self.state == DriverState::Closing {
if matches!(self.close_outcome, Some(CloseOutcome::Io(_))) {
return Poll::Ready(self.finish_with_current_outcome());
}
self.state = DriverState::Drained;
}
if self.state == DriverState::Running
&& Pin::new(&mut self.shutting_down).poll(cx).is_ready()
{
self.begin_close(CloseOutcome::Graceful);
continue;
}
match self.state {
DriverState::AwaitingPreface => {
let poll = match self.role {
Role::Server => self.poll_read_preface(cx),
Role::Client => {
self.queue_client_preface();
Poll::Ready(Ok(()))
}
};
match poll {
Poll::Ready(Ok(())) => self.state = DriverState::NeedsServerSettings,
Poll::Ready(Err(e)) => {
self.close_outcome = Some(e);
return Poll::Ready(self.finish_with_current_outcome());
}
Poll::Pending => {
if self.park(cx) {
return Poll::Pending;
}
}
}
}
DriverState::NeedsServerSettings => {
self.queue_settings();
let raise = i64::from(self.config.initial_connection_window_size())
- INITIAL_CONNECTION_RECV_WINDOW;
if raise > 0 {
let raise = u32::try_from(raise).unwrap_or(u32::MAX);
self.queue_window_update(0, raise);
self.connection_recv_window += i64::from(raise);
}
self.state = DriverState::Running;
}
DriverState::Running => match self.poll_advance_read(cx) {
Poll::Ready(Ok(Action::Continue)) => {}
Poll::Ready(Ok(Action::Emit(conn))) => {
return Poll::Ready(Some(Ok(*conn)));
}
Poll::Ready(Ok(Action::Close(outcome))) => {
self.begin_close(outcome);
}
Poll::Ready(Err(e)) => {
self.begin_close(e);
}
Poll::Pending => {
if self.park(cx) {
return Poll::Pending;
}
}
},
DriverState::Closing => unreachable!("handled above once write_buf is drained"),
DriverState::Drained => match self.poll_drain_peer(cx) {
Poll::Ready(()) => {
return Poll::Ready(self.finish_with_current_outcome());
}
Poll::Pending => return Poll::Pending,
},
}
}
cx.waker().wake_by_ref();
Poll::Pending
}
fn park(&mut self, cx: &mut Context<'_>) -> bool {
self.connection.outbound_waker().register(cx.waker());
!self.has_pending_handler_signals() && !self.has_pending_outbound_progress()
}
fn finish_with_current_outcome(&mut self) -> Option<Result<Conn<H2Transport>, H2Error>> {
self.finished = true;
self.connection.fail_pending_pings(
io::ErrorKind::ConnectionAborted,
"h2 connection closed before PING ACK",
);
self.connection.wake_peer_settings_waiters();
match self.close_outcome.take() {
None | Some(CloseOutcome::Graceful) => None,
Some(CloseOutcome::Protocol(code)) => Some(Err(H2Error::Protocol(code))),
Some(CloseOutcome::Io(e)) => Some(Err(H2Error::Io(e))),
}
}
fn begin_close(&mut self, outcome: CloseOutcome) {
log::trace!("h2 driver: begin_close({outcome:?})");
let code = match &outcome {
CloseOutcome::Graceful => Some(H2ErrorCode::NoError),
CloseOutcome::Protocol(code) => Some(*code),
CloseOutcome::Io(_) => None,
};
if self.close_outcome.is_none() {
self.close_outcome = Some(outcome);
}
if let Some(code) = code {
self.queue_goaway(self.last_peer_stream_id, code);
}
self.state = DriverState::Closing;
}
fn poll_fill_to(&mut self, target: usize, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if self.read_buf.len() < target {
self.read_buf.resize(target, 0);
}
while self.read_filled < target {
let n = ready!(
Pin::new(&mut self.transport)
.poll_read(cx, &mut self.read_buf[self.read_filled..target])
)?;
if n == 0 {
return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof)));
}
self.read_filled += n;
}
Poll::Ready(Ok(()))
}
fn poll_drain_peer(&mut self, cx: &mut Context<'_>) -> Poll<()> {
const MAX_DISCARD_ITERATIONS: usize = 256;
let mut scratch = [0u8; 512];
for _ in 0..MAX_DISCARD_ITERATIONS {
match Pin::new(&mut self.transport).poll_read(cx, &mut scratch) {
Poll::Ready(Ok(0) | Err(_)) | Poll::Pending => {
return Poll::Ready(());
}
Poll::Ready(Ok(_)) => {}
}
}
Poll::Ready(())
}
pub(super) fn closed_reason(&self, stream_id: u32) -> Option<ClosedReason> {
self.closed_streams.reason(stream_id)
}
}