use crate::{
http::{ReqResBuffer, StatusCode},
http2::{
Http2Buffer, Http2Error, Http2ErrorCode, Http2Inner, Http2Params, Http2RecvStatus,
Http2SendStatus, Scrp, Sorp, Windows,
common_flags::CommonFlags,
frame_init::{FrameInit, FrameInitTy},
go_away_frame::GoAwayFrame,
headers_frame::HeadersFrame,
hpack_decoder::HpackDecoder,
http2_data::Http2DataPartsMut,
reader_data::ReaderData,
reset_stream_frame::ResetStreamFrame,
stream_receiver::{StreamControlRecvParams, StreamOverallRecvParams},
stream_state::StreamState,
u31::U31,
},
misc::{
ConnectionState, LeaseMut, Usize,
net::{read_header, read_payload},
},
stream::{StreamReader, StreamWriter},
sync::{AtomicBool, AtomicWaker},
};
use core::{
future::poll_fn,
mem,
pin::pin,
sync::atomic::Ordering,
task::{Context, Poll},
};
/// Validates the received body against a previously declared `content-length`, if any.
///
/// When `content_length` is `None` nothing is checked.
///
/// # Errors
///
/// Returns a `PROTOCOL_ERROR` go-away carrying `Http2Error::InvalidContentLength` when the
/// declared length differs from the number of bytes in `rrb.body`.
pub(crate) fn check_content_length(
  content_length: Option<usize>,
  rrb: &ReqResBuffer,
) -> crate::Result<()> {
  match content_length {
    Some(declared) if declared != rrb.body.len() => {
      Err(protocol_err(Http2Error::InvalidContentLength))
    }
    _ => Ok(()),
  }
}
/// Turns a stored frame-reader error into a `Result`, leaving `None` in its place.
///
/// # Errors
///
/// Returns the taken error when one was stored.
pub(crate) fn frame_reader_rslt(err: &mut Option<crate::Error>) -> crate::Result<()> {
  err.take().map_or(Ok(()), Err)
}
pub(crate) fn connection_state(atomic_bool: &AtomicBool) -> ConnectionState {
ConnectionState::from(atomic_bool.load(Ordering::Relaxed))
}
/// Looks up the mutable control-receive parameters of `stream_id`.
///
/// # Errors
///
/// Returns a `PROTOCOL_ERROR` go-away (`Http2Error::UnknownStreamId`) when the stream is not
/// registered in `scrp`.
#[track_caller]
pub(crate) fn scrp_mut(
  scrp: &mut Scrp,
  stream_id: U31,
) -> crate::Result<&mut StreamControlRecvParams> {
  match scrp.get_mut(&stream_id) {
    Some(params) => Ok(params),
    None => Err(protocol_err(Http2Error::UnknownStreamId)),
  }
}
/// Looks up the mutable overall-receive parameters of `stream_id`.
///
/// # Errors
///
/// Returns a `PROTOCOL_ERROR` go-away (`Http2Error::UnknownStreamId`) when the stream is not
/// registered in `sorp`.
#[track_caller]
pub(crate) fn sorp_mut(
  sorp: &mut Sorp,
  stream_id: U31,
) -> crate::Result<&mut StreamOverallRecvParams> {
  match sorp.get_mut(&stream_id) {
    Some(params) => Ok(params),
    None => Err(protocol_err(Http2Error::UnknownStreamId)),
  }
}
/// Polls the receiving side of the overall stream `stream_id` and finishes it when possible.
///
/// The outcomes are checked in this order:
///
/// - Connection closed: the stream's entry is removed from `sorps`. A stored frame-reader error,
///   if any, is returned. Otherwise the result is `ClosedConnection` with the buffered `rrb`.
/// - Connection open but stream closed: the stream is finalized as `ClosedStream`.
/// - End of stream received (`stream_state.recv_eos()`): the stream is finalized as `Eos`.
/// - Otherwise the current task's waker is stored in the stream entry and `Poll::Pending` is
///   returned.
///
/// Finalizing (the local `eos!` macro) does the following in order:
/// 1. Takes the buffered `rrb`.
/// 2. Removes the entry from `sorps`.
/// 3. Validates `content_length` against the body.
/// 4. Hands status code, stream state and windows to `cb_eos`, whose output is wrapped in the
///    returned status.
///
/// # Errors
///
/// - `UnknownStreamId` when `stream_id` has no entry in `sorps`.
/// - `InvalidContentLength` when the body disagrees with the declared length. The entry has
///   already been removed at that point.
/// - The stored frame-reader error when the connection is closed.
pub(crate) fn manage_recurrent_receiving_of_overall_stream<EOS, const IS_CLIENT: bool>(
  cx: &mut Context<'_>,
  mut hdpm: Http2DataPartsMut<'_, IS_CLIENT>,
  is_conn_open: &AtomicBool,
  stream_id: U31,
  cb_eos: impl FnOnce(&mut Http2DataPartsMut<'_, IS_CLIENT>, StatusCode, StreamState, Windows) -> EOS,
) -> Poll<crate::Result<(Http2RecvStatus<EOS, ()>, ReqResBuffer)>> {
  // Copies every needed field out of `$sorp` before its entry is removed from the map, then
  // validates the body length and builds the final status through `cb_eos`.
  macro_rules! eos {
    ($hdpm:expr, $hrs:ident, $sorp:expr, $stream_id:expr) => {{
      let content_length = $sorp.content_length;
      let rrb = mem::take(&mut $sorp.rrb);
      let status_code = $sorp.status_code;
      let stream_state = $sorp.stream_state;
      let windows = $sorp.windows;
      drop($hdpm.hb.sorps.remove($stream_id));
      check_content_length(content_length, &rrb)?;
      let eos = cb_eos(&mut $hdpm, status_code, stream_state, windows);
      Poll::Ready(Ok((Http2RecvStatus::$hrs(eos), rrb)))
    }};
  }
  let sorp = sorp_mut(&mut hdpm.hb.sorps, stream_id)?;
  match (connection_state(is_conn_open), sorp.is_stream_open) {
    // A closed connection wins regardless of the stream flag.
    (ConnectionState::Closed, false | true) => {
      let rrb = mem::take(&mut sorp.rrb);
      drop(hdpm.hb.sorps.remove(&stream_id));
      // Surface the frame reader's failure (if any) instead of a plain closed status.
      frame_reader_rslt(hdpm.frame_reader_error)?;
      return Poll::Ready(Ok((Http2RecvStatus::ClosedConnection, rrb)));
    }
    (ConnectionState::Open, false) => return eos!(hdpm, ClosedStream, sorp, &stream_id),
    (ConnectionState::Open, true) => {}
  }
  if sorp.stream_state.recv_eos() {
    return eos!(hdpm, Eos, sorp, &stream_id);
  }
  // Still receiving: remember who to wake when more data or a state change arrives.
  sorp.waker.clone_from(cx.waker());
  Poll::Pending
}
pub(crate) const fn protocol_err(error: Http2Error) -> crate::Error {
crate::Error::Http2ErrorGoAway(Http2ErrorCode::ProtocolError, error)
}
pub(crate) async fn process_higher_operation_err<HB, SW, const IS_CLIENT: bool>(
err: &crate::Error,
inner: &Http2Inner<HB, SW, IS_CLIENT>,
) where
HB: LeaseMut<Http2Buffer>,
SW: StreamWriter,
{
match err {
crate::Error::Http2ErrorGoAway(http2_error_code, _) => {
send_go_away(*http2_error_code, inner).await;
}
crate::Error::Http2FlowControlError(_, stream_id) => {
let _ = send_reset_stream(Http2ErrorCode::FlowControlError, inner, stream_id.into()).await;
}
_ => {
send_go_away(Http2ErrorCode::InternalError, inner).await;
}
}
}
/// Reads the next frame header plus its payload from `rd`. Frames whose header does not decode
/// into a [`FrameInit`] are skipped.
///
/// The 9-byte header is read first, and its declared payload length must not exceed
/// `max_frame_len`. A header for which `FrameInit::from_array` yields no `FrameInit` (presumably
/// an unknown frame type — TODO confirm) is discarded together with its payload. Such a payload
/// may be at most 32 bytes, and at most `_max_frames_mismatches!()` such frames are tolerated per
/// call.
///
/// `IS_HEADER_BLOCK` is `true` while a header block is being assembled (see the CONTINUATION
/// loop in `read_header_and_continuations`). In that mode an undecodable frame is rejected
/// instead of skipped.
///
/// While the read is pending, `read_frame_waker` is registered with the current task and the
/// shared connection flag is checked. If the connection has been closed, `Ok(None)` is returned.
///
/// # Errors
///
/// - `FRAME_SIZE_ERROR` go-away (`LargeArbitraryFrameLen`) when the payload exceeds
///   `max_frame_len`.
/// - `UnexpectedContinuationFrame` for an undecodable frame inside a header block.
/// - `LargeIgnorableFrameLen` when a skipped frame declares more than 32 payload bytes.
/// - `VeryLargeAmountOfFrameMismatches` when too many frames were skipped.
/// - Any buffer or stream-reader error.
pub(crate) async fn read_frame<SR, const IS_HEADER_BLOCK: bool>(
  is_conn_open: &AtomicBool,
  max_frame_len: u32,
  rd: &mut ReaderData<SR>,
  read_frame_waker: &AtomicWaker,
) -> crate::Result<Option<FrameInit>>
where
  SR: StreamReader,
{
  let mut fut = pin!(async move {
    for _ in 0.._max_frames_mismatches!() {
      rd.pfb.clear_if_following_is_empty();
      // Room for at least one 9-byte frame header.
      rd.pfb.reserve(9)?;
      // Bytes already buffered after the current frame (left over from an earlier read) are
      // consumed before reading more from the stream.
      let mut read = rd.pfb.following_len();
      let buffer = rd.pfb.following_rest_mut();
      let array = read_header::<0, 9, _>(buffer, &mut read, &mut rd.stream_reader).await?;
      let (fi_opt, data_len) = FrameInit::from_array(array);
      if data_len > max_frame_len {
        return Err(crate::Error::Http2ErrorGoAway(
          Http2ErrorCode::FrameSizeError,
          Http2Error::LargeArbitraryFrameLen { received: data_len },
        ));
      }
      let data_len_usize = *Usize::from_u32(data_len);
      let Some(fi) = fi_opt else {
        // Inside a header block only CONTINUATION frames may follow.
        if IS_HEADER_BLOCK {
          return Err(protocol_err(Http2Error::UnexpectedContinuationFrame));
        }
        if data_len > 32 {
          return Err(protocol_err(Http2Error::LargeIgnorableFrameLen));
        }
        // Whole frame = 9-byte header + payload.
        let frame_len = data_len_usize.wrapping_add(9);
        let (antecedent_len, following_len) = if let Some(to_read) = frame_len.checked_sub(read) {
          // The frame extends past the buffered bytes: drop the rest straight from the stream
          // and treat the whole buffer as consumed.
          rd.stream_reader.read_skip(to_read).await?;
          (rd.pfb.all().len(), 0)
        } else {
          // The frame is fully buffered: move past it and keep the bytes that follow.
          (rd.pfb.current_end_idx().wrapping_add(frame_len), read.wrapping_sub(frame_len))
        };
        rd.pfb.set_indices(antecedent_len, 0, following_len)?;
        continue;
      };
      _trace!("Received frame: {fi:?}");
      read_payload((9, data_len_usize), &mut rd.pfb, &mut read, &mut rd.stream_reader).await?;
      return Ok(fi);
    }
    Err(protocol_err(Http2Error::VeryLargeAmountOfFrameMismatches))
  });
  poll_fn(|cx| match fut.as_mut().poll(cx) {
    Poll::Ready(Ok(fi)) => Poll::Ready(Ok(Some(fi))),
    Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
    Poll::Pending => {
      // Lets a connection shutdown (`send_go_away` wakes `read_frame_waker`) interrupt an
      // otherwise idle read.
      read_frame_waker.register(cx.waker());
      if connection_state(is_conn_open).is_closed() {
        return Poll::Ready(Ok(None));
      }
      Poll::Pending
    }
  })
  .await
}
/// Reads a complete header block for the stream of `fi` and decodes it through
/// `HeadersFrame::read`, following CONTINUATION frames when `fi` lacks END_HEADERS.
///
/// Buffer handling:
///
/// - For regular headers (`IS_TRAILER == false`) the whole `rrb` is reset beforehand.
/// - For trailers the existing body is preserved.
/// - In the multi-frame case the raw header-block fragments are staged in `rrb.body` starting at
///   `rrb_body_start`. After decoding, the body is truncated back to that offset, so only the
///   staged bytes are discarded and everything decoded into `rrb` is kept.
///
/// Returns `(content_length, has_eos, headers_cb(&headers_frame)?)`.
///
/// # Errors
///
/// - `MissingEOSInTrailer` when a trailer frame lacks END_STREAM.
/// - `IncompleteHeader` when the connection is closed before the block is complete.
/// - `UnexpectedContinuationFrame` when the next frame belongs to another stream or is not a
///   CONTINUATION frame.
/// - `VeryLargeAmountOfContinuationFrames` after `_max_continuation_frames!()` fragments
///   without END_HEADERS.
/// - `FRAME_SIZE_ERROR` go-away (`VeryLargeHeadersLen`) when the decoded headers are over size.
/// - Any error from `read_frame`, `HeadersFrame::read` or `headers_cb`.
pub(crate) async fn read_header_and_continuations<
  H,
  SR,
  const IS_CLIENT: bool,
  const IS_TRAILER: bool,
>(
  fi: FrameInit,
  is_conn_open: &AtomicBool,
  hp: &mut Http2Params,
  hpack_dec: &mut HpackDecoder,
  rd: &mut ReaderData<SR>,
  read_frame_waker: &AtomicWaker,
  rrb: &mut ReqResBuffer,
  mut headers_cb: impl FnMut(&HeadersFrame<'_>) -> crate::Result<H>,
) -> crate::Result<(Option<usize>, bool, H)>
where
  SR: StreamReader,
{
  if IS_TRAILER && !fi.cf.has_eos() {
    return Err(protocol_err(Http2Error::MissingEOSInTrailer));
  }
  // Offset in `rrb.body` where raw header-block fragments get staged; 0 for regular headers.
  let rrb_body_start = if IS_TRAILER {
    rrb.body.len()
  } else {
    rrb.clear();
    0
  };
  if fi.cf.has_eoh() {
    // Single-frame block: decode straight from the frame buffer.
    let (content_length, hf) = HeadersFrame::read::<IS_CLIENT, IS_TRAILER>(
      Some(rd.pfb.current()),
      fi,
      hp,
      hpack_dec,
      (rrb, rrb_body_start),
    )?;
    if hf.is_over_size() {
      return Err(crate::Error::Http2ErrorGoAway(
        Http2ErrorCode::FrameSizeError,
        Http2Error::VeryLargeHeadersLen,
      ));
    }
    return Ok((content_length, hf.has_eos(), headers_cb(&hf)?));
  }
  // Multi-frame block: accumulate every fragment in the body before decoding.
  rrb.body.extend_from_copyable_slice(rd.pfb.current())?;
  'continuation_frames: {
    for _ in 0.._max_continuation_frames!() {
      let Some(frame_fi) =
        read_frame::<_, true>(is_conn_open, hp.max_frame_len(), rd, read_frame_waker).await?
      else {
        return Err(protocol_err(Http2Error::IncompleteHeader));
      };
      let has_diff_id = fi.stream_id != frame_fi.stream_id;
      let is_not_continuation = frame_fi.ty != FrameInitTy::Continuation;
      if has_diff_id || is_not_continuation {
        return Err(protocol_err(Http2Error::UnexpectedContinuationFrame));
      }
      rrb.body.extend_from_copyable_slice(rd.pfb.current())?;
      if frame_fi.cf.has_eoh() {
        break 'continuation_frames;
      }
    }
    return Err(protocol_err(Http2Error::VeryLargeAmountOfContinuationFrames));
  }
  let (content_length, hf) =
    HeadersFrame::read::<IS_CLIENT, IS_TRAILER>(None, fi, hp, hpack_dec, (rrb, rrb_body_start))?;
  // Drop only the staged raw fragments. For trailers this restores the previous body. For
  // regular headers (`rrb_body_start == 0`) it empties the body. Resetting the whole buffer
  // (`rrb.clear()`) here would also discard what `HeadersFrame::read` just decoded into `rrb`.
  rrb.body.truncate(rrb_body_start);
  if hf.is_over_size() {
    return Err(crate::Error::Http2ErrorGoAway(
      Http2ErrorCode::FrameSizeError,
      Http2Error::VeryLargeHeadersLen,
    ));
  }
  Ok((content_length, hf.has_eos(), headers_cb(&hf)?))
}
/// Marks the connection as closed, wakes every waiting task and then writes a GOAWAY frame.
///
/// Under the `inner.hd` lock, this:
/// - clears the open flag;
/// - wakes every queued initial server stream and every control/overall stream waker, draining
///   both maps;
/// - wakes the frame reader;
/// - captures `last_stream_id`.
///
/// After the lock is released, the frame is written through `inner.wd`. Write errors are ignored.
pub(crate) async fn send_go_away<HB, SW, const IS_CLIENT: bool>(
  error_code: Http2ErrorCode,
  inner: &Http2Inner<HB, SW, IS_CLIENT>,
) where
  HB: LeaseMut<Http2Buffer>,
  SW: StreamWriter,
{
  let last_stream_id = {
    let mut hd_guard = inner.hd.lock().await;
    let hdpm = hd_guard.parts_mut();
    inner.is_conn_open.store(false, Ordering::Relaxed);
    while let Some(pending) = hdpm.hb.initial_server_streams_local.pop_front() {
      pending.wake();
    }
    hdpm.hb.scrps.drain().for_each(|(_, scrp)| scrp.waker.wake());
    hdpm.hb.sorps.drain().for_each(|(_, sorp)| sorp.waker.wake());
    inner.read_frame_waker.wake();
    *hdpm.last_stream_id
  };
  let frame = GoAwayFrame::new(error_code, last_stream_id);
  let _rslt = inner.wd.lock().await.stream_writer.write_all(&frame.bytes()).await;
}
/// Writes an RST_STREAM frame for `stream_id` and marks any local state of that stream as closed.
///
/// The frame is written under the `inner.wd` lock, and write errors are ignored. That lock is
/// released before `inner.hd` is taken. In both the control (`scrps`) and overall (`sorps`) maps,
/// a matching entry is flagged as not open, moved to `StreamState::Closed`, and has its waker
/// woken.
///
/// Returns `true` when at least one of the two maps had an entry for `stream_id`.
pub(crate) async fn send_reset_stream<HB, SW, const IS_CLIENT: bool>(
  error_code: Http2ErrorCode,
  inner: &Http2Inner<HB, SW, IS_CLIENT>,
  stream_id: U31,
) -> bool
where
  HB: LeaseMut<Http2Buffer>,
  SW: StreamWriter,
{
  let frame = ResetStreamFrame::new(error_code, stream_id);
  {
    let mut wd_guard = inner.wd.lock().await;
    let _rslt = wd_guard.stream_writer.write_all(&frame.bytes()).await;
  }
  let mut hd_guard = inner.hd.lock().await;
  let hdpm = hd_guard.parts_mut();
  let mut found_control = false;
  if let Some(scrp) = hdpm.hb.scrps.get_mut(&stream_id) {
    scrp.is_stream_open = false;
    scrp.stream_state = StreamState::Closed;
    scrp.waker.wake_by_ref();
    found_control = true;
  }
  let mut found_overall = false;
  if let Some(sorp) = hdpm.hb.sorps.get_mut(&stream_id) {
    sorp.is_stream_open = false;
    sorp.stream_state = StreamState::Closed;
    sorp.waker.wake_by_ref();
    found_overall = true;
  }
  found_control || found_overall
}
/// Maps the END_STREAM flag of a received HEADERS frame to the resulting stream state.
///
/// With END_STREAM the result is `HalfClosedRemote`, otherwise `Open`. This mirrors the
/// transitions of RFC 9113 §5.1. Presumably it is used on the server side when a request's
/// headers arrive; verify against callers.
pub(crate) const fn server_header_stream_state(has_eos: bool) -> StreamState {
  match has_eos {
    true => StreamState::HalfClosedRemote,
    false => StreamState::Open,
  }
}
/// Reports whether the receiving side of `sorp` has reached a terminal status.
///
/// Checks, in order:
/// - a closed connection (`ClosedConnection`);
/// - a closed stream (`ClosedStream`);
/// - a received end of stream (`Eos`).
///
/// Only the last two invoke `eos_cb`, whose value is wrapped in the status. Returns `Ok(None)`
/// while the stream can still receive.
///
/// # Errors
///
/// Propagates any error returned by `eos_cb`.
pub(crate) fn status_recv<EOS, ONG>(
  is_conn_open: &AtomicBool,
  sorp: &mut StreamOverallRecvParams,
  eos_cb: impl FnOnce(&mut StreamOverallRecvParams) -> crate::Result<EOS>,
) -> crate::Result<Option<Http2RecvStatus<EOS, ONG>>> {
  let status = if connection_state(is_conn_open).is_closed() {
    Http2RecvStatus::ClosedConnection
  } else if !sorp.is_stream_open {
    Http2RecvStatus::ClosedStream(eos_cb(sorp)?)
  } else if sorp.stream_state.recv_eos() {
    Http2RecvStatus::Eos(eos_cb(sorp)?)
  } else {
    return Ok(None);
  };
  Ok(Some(status))
}
/// Reports why data cannot be sent on the stream described by `sorp`, if anything prevents it.
///
/// Returns, in order of precedence:
/// - `ClosedConnection` when the connection is closed;
/// - `ClosedStream` when the stream is closed;
/// - `InvalidState` when the stream state forbids sending for this side (`IS_CLIENT`).
///
/// Returns `None` when sending is allowed.
pub(crate) fn status_send<const IS_CLIENT: bool>(
  is_conn_open: &AtomicBool,
  sorp: &StreamOverallRecvParams,
) -> Option<Http2SendStatus> {
  let status = if connection_state(is_conn_open).is_closed() {
    Http2SendStatus::ClosedConnection
  } else if !sorp.is_stream_open {
    Http2SendStatus::ClosedStream
  } else if sorp.stream_state.can_send::<IS_CLIENT>() {
    return None;
  } else {
    Http2SendStatus::InvalidState
  };
  Some(status)
}
/// Strips HTTP/2 padding from a frame payload when `cf` carries the PADDED flag.
///
/// The first byte of `data` is the pad length. On success, `data` is narrowed to the bytes
/// between that byte and the trailing padding, and the pad length is returned. Without the flag,
/// `data` is left untouched and `Ok(None)` is returned.
///
/// # Errors
///
/// Returns a `PROTOCOL_ERROR` go-away (`Http2Error::InvalidFramePad`) and leaves `data`
/// unchanged in two cases:
/// - the payload is empty;
/// - the pad length is greater than or equal to the payload length (RFC 9113 §6.1).
pub(crate) fn trim_frame_pad(cf: CommonFlags, data: &mut &[u8]) -> crate::Result<Option<u8>> {
  if !cf.has_pad() {
    return Ok(None);
  }
  let bytes = *data;
  let (&pad_len, payload) =
    bytes.split_first().ok_or_else(|| protocol_err(Http2Error::InvalidFramePad))?;
  let content = payload
    .len()
    .checked_sub(usize::from(pad_len))
    .and_then(|end| payload.get(..end))
    .ok_or_else(|| protocol_err(Http2Error::InvalidFramePad))?;
  *data = content;
  Ok(Some(pad_len))
}
/// Writes all slices of `array` to `stream_writer` in one vectored write.
///
/// When the connection is already closed, nothing is written and `Ok(())` is returned. With
/// tracing enabled, the headers of small frames are decoded for the log. A leading connection
/// preface is not decoded.
///
/// # Errors
///
/// Propagates any error from `write_all_vectored`.
pub(crate) async fn write_array<SW, const N: usize>(
  array: [&[u8]; N],
  is_conn_open: &AtomicBool,
  stream_writer: &mut SW,
) -> crate::Result<()>
where
  SW: StreamWriter,
{
  if connection_state(is_conn_open).is_closed() {
    return Ok(());
  }
  _trace!("Sending frame(s): {:?}", {
    // Decodes only small frames: a 9-byte header followed by at most 36 payload bytes.
    let decode = |frame: &[u8]| -> Option<FrameInit> {
      let [a, b, c, d, e, f, g, h, i, rest @ ..] = frame else {
        return None;
      };
      if rest.len() > 36 {
        return None;
      }
      FrameInit::from_array([*a, *b, *c, *d, *e, *f, *g, *h, *i]).0
    };
    let mut decoded = [None; N];
    for (idx, (slot, frame)) in decoded.iter_mut().zip(array.iter()).enumerate() {
      // The connection preface can only be the leading slice and is not a frame.
      if idx == 0 && frame == crate::http2::PREFACE {
        continue;
      }
      *slot = decode(frame);
    }
    decoded
  });
  stream_writer.write_all_vectored(&array).await?;
  Ok(())
}