use core::{
cell::{RefCell, RefMut},
fmt,
future::{Future, poll_fn},
mem,
net::SocketAddr,
pin::{Pin, pin},
task::{Context, Poll, Waker, ready},
};
use std::{
collections::{HashMap, VecDeque},
io,
net::Shutdown,
rc::Rc,
time::Duration,
};
use tracing::error;
use xitca_io::{
bytes::{Buf, BufMut, BytesMut},
io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, write_all},
};
use xitca_service::Service;
use xitca_unsafe_collection::{
futures::{Select, SelectOutput},
no_hash::NoHashBuilder,
};
use crate::{
body::{Body, SizeHint},
bytes::Bytes,
config::HttpServiceConfig,
date::{DateTime, DateTimeHandle},
http::{
Extension, HeaderMap, Method, Protocol, Request, RequestExt, Response, Uri, Version,
header::{CONTENT_LENGTH, DATE},
uri,
},
util::{
futures::Queue,
timer::{KeepAlive, Timeout},
},
};
use super::{
STREAM_MUST_EXIST,
body::RequestBody,
proto::{
error::Error as ProtoError,
frame::{
PREFACE,
data::Data,
go_away::GoAway,
head,
headers::{self, ResponsePseudo},
ping::Ping,
reason::Reason,
reset::Reset,
settings::{self, Settings},
stream_id::StreamId,
window_update::WindowUpdate,
},
hpack,
last_stream_id::LastStreamId,
ping_pong::PingPong,
reset_counter::ResetCounter,
size::BodySize,
stream::{RecvData, Stream, StreamError, TryRemove},
threshold::StreamRecvWindowThreshold,
},
};
/// Read-side state of one HTTP/2 connection: splits the byte stream into frames,
/// HPACK-decodes header blocks and turns new streams into service requests.
struct DecodeContext<'a, S> {
    // Limit handed to `Headers::load_hpack` for each decoded header list.
    max_header_list_size: usize,
    // New streams beyond this many open ones are refused with REFUSED_STREAM.
    max_concurrent_streams: usize,
    // Handed to every `RequestBody` created by this context.
    recv_threshold: StreamRecvWindowThreshold,
    // Single HPACK decoder used for every header block on the connection.
    decoder: hpack::Decoder,
    // Bytes still needed for the frame being buffered (payload + remaining 6 header
    // bytes); 0 means the next 3-byte length prefix has not been read yet.
    next_frame_len: usize,
    // Header block waiting for CONTINUATION frames: the partially decoded headers and
    // the undecoded remainder of the block.
    continuation: Option<(headers::Headers, BytesMut)>,
    // Service that handles decoded requests.
    service: &'a S,
    // Connection state shared with the writer, response tasks and request bodies.
    ctx: &'a Shared,
    // Peer address, attached to every request's extension.
    addr: SocketAddr,
    // Date source handed to response tasks.
    date: &'a DateTimeHandle,
}
/// Body frame (data chunk or trailers) as yielded by response bodies and buffered for
/// request bodies.
pub(super) type Frame = crate::body::Frame<Bytes>;
/// Buffer of received body frames; passed to the stream receive operations
/// (`try_recv_data`, `try_recv_trailers`, `maybe_close_recv`).
pub(super) type FrameBuffer = super::util::FrameBuffer<Frame>;
/// State of the server-initiated keep-alive PING.
///
/// Transitions: `Idle` -> `Pending` when a ping is requested
/// (`FlowControl::try_set_pending_ping`), `Pending` -> `InFlight` once the encoder has
/// written it, and back to `Idle` when the peer's PING ACK arrives.
///
/// `Eq` accompanies `PartialEq` (total equality on a fieldless enum). `Copy` and `Debug`
/// let the state be compared and logged without moves.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum KeepalivePing {
    /// No keep-alive ping outstanding.
    Idle,
    /// A ping was requested but has not been written yet.
    Pending,
    /// A ping was written and its ACK is awaited.
    InFlight,
}
/// Connection-wide HTTP/2 state shared (through `Shared`) by the decoder, the encoder,
/// the response tasks and the request bodies.
pub(super) struct FlowControl {
    // Bytes we may still send on the connection (peer-granted connection window).
    send_connection_window: usize,
    // Send window given to newly opened streams; signed because
    // `update_and_wake_send_streams` can apply negative SETTINGS deltas.
    send_stream_initial_window: i64,
    // Peer's maximum frame size, used when splitting outbound data and header blocks.
    max_frame_size: usize,
    // Receive window granted to each new stream (our configured initial window size).
    recv_stream_initial_window: usize,
    // Remaining connection-level receive window; decremented per DATA frame and
    // replenished when the encoder emits a connection WINDOW_UPDATE.
    recv_connection_window: usize,
    // Open streams keyed by id.
    pub(super) stream_map: HashMap<StreamId, Stream, NoHashBuilder>,
    // Budget of counted resets; exhausting it leads to GOAWAY ENHANCE_YOUR_CALM.
    reset_counter: ResetCounter,
    // Highest stream id seen / GOAWAY bookkeeping.
    last_stream_id: LastStreamId,
    // Outbound frames and control state consumed by the write task.
    pub(super) queue: WriterQueue,
    // Shared buffer for received body frames.
    pub(super) frame_buf: FrameBuffer,
}
impl FlowControl {
    /// Register a new stream under `id`, seeding its windows and frame size from the
    /// current connection settings.
    fn insert_stream(&mut self, id: StreamId, end_stream: bool, content_length: SizeHint) {
        let stream = Stream::new(
            self.send_stream_initial_window,
            self.max_frame_size,
            self.recv_stream_initial_window,
            content_length,
            end_stream,
        );
        self.stream_map.insert(id, stream);
    }

    /// Connection error (GOAWAY PROTOCOL_ERROR) when `last_stream_id` reports `id` as idle
    /// (presumably: a stream the peer has not opened yet).
    fn check_not_idle(&self, id: StreamId) -> Result<(), Error> {
        if self.last_stream_id.check_idle(id) {
            Err(Error::GoAway(Reason::PROTOCOL_ERROR))
        } else {
            Ok(())
        }
    }

    /// Charge `len` received flow-controlled bytes to the connection receive window.
    /// Overrunning the window is a connection-level FLOW_CONTROL_ERROR.
    fn recv_window_dec(&mut self, len: usize) -> Result<(), Error> {
        self.recv_connection_window = self
            .recv_connection_window
            .checked_sub(len)
            .ok_or(Error::GoAway(Reason::FLOW_CONTROL_ERROR))?;
        Ok(())
    }

    /// Add `delta` (possibly negative; 0 gives a pure wake-up pass) to every stream's send
    /// window, then wake each stream whose window is positive.
    fn update_and_wake_send_streams(&mut self, delta: i64) {
        for state in self.stream_map.values_mut() {
            let sf = &mut state.send;
            sf.window += delta;
            if sf.window > 0 {
                sf.wake();
            }
        }
    }

    /// Handle a RST_STREAM received from the peer.
    ///
    /// Every peer reset is charged to the reset budget (GOAWAY ENHANCE_YOUR_CALM when it is
    /// exhausted). Resets for streams no longer in the map are otherwise ignored.
    #[cold]
    #[inline(never)]
    fn try_reset_stream(&mut self, id: StreamId) -> Result<(), Error> {
        self.try_tick_reset()?;
        let Some(stream) = self.stream_map.get_mut(&id) else {
            return Ok(());
        };
        stream.try_set_peer_reset();
        self.inline_remove_stream(id)
    }

    /// Close the receive side of stream `id` when its request body goes away.
    ///
    /// The window returned by `maybe_close_recv` plus the caller-supplied `pending_window`
    /// is scheduled as connection-level credit. The stream is removed when `try_remove`
    /// allows it. A stream error reported during removal is answered with RST_STREAM,
    /// escalating to GOAWAY when the reset budget is exhausted. The writer is woken
    /// whenever something new has to be written.
    ///
    /// # Panics
    /// If `id` is not in the stream map.
    pub(super) fn request_body_drop(&mut self, id: StreamId, pending_window: usize) {
        let stream = self.stream_map.get_mut(&id).expect(STREAM_MUST_EXIST);
        let window = stream.maybe_close_recv(&mut self.frame_buf) + pending_window;
        let mut wake = window > 0;
        self.queue.connection_window_update(window);
        let remove = stream.try_remove();
        if let Err(err) = self.remove_stream(id, remove) {
            wake = true;
            if let Err(err) = self.try_push_reset(id, err.reason()) {
                self.go_away(err);
            }
        }
        if wake {
            self.queue.wake();
        }
    }

    /// Close the send side of stream `id` after its response task finished, and remove
    /// the stream when possible.
    ///
    /// Returns `Err(())` only when a required RST_STREAM could not be queued because the
    /// reset budget is exhausted. `go_away` has then been invoked instead.
    ///
    /// # Panics
    /// If `id` is not in the stream map.
    fn response_task_done(&mut self, id: StreamId) -> Result<(), ()> {
        let stream = self.stream_map.get_mut(&id).expect(STREAM_MUST_EXIST);
        stream.close_send();
        let remove = stream.try_remove();
        if let Err(err) = self.remove_stream(id, remove) {
            if let Err(err) = self.try_push_reset(id, err.reason()) {
                self.go_away(err);
                return Err(());
            }
        }
        Ok(())
    }

    /// Try to drop stream `id` after a frame that may have finished it. A cancelled
    /// receive side is first promoted to closed. Unknown ids are ignored; stream errors
    /// from removal are returned (and become RST_STREAM via `Error::from`).
    fn inline_remove_stream(&mut self, id: StreamId) -> Result<(), Error> {
        if let Some(stream) = self.stream_map.get_mut(&id) {
            stream.promote_cancel_to_close_recv();
            let remove = stream.try_remove();
            self.remove_stream(id, remove)?;
        };
        Ok(())
    }

    /// Apply a `TryRemove` decision:
    /// - `Keep`: keep the stream and report success.
    /// - `ResetKeep`: keep the stream and report the error.
    /// - `ResetRemove`: remove the stream and report the error.
    /// - `Remove`: remove the stream and report success.
    fn remove_stream(&mut self, id: StreamId, remove: TryRemove) -> Result<(), StreamError> {
        let res = match remove {
            TryRemove::Keep => return Ok(()),
            TryRemove::ResetKeep(err) => return Err(err),
            TryRemove::ResetRemove(err) => Err(err),
            TryRemove::Remove => Ok(()),
        };
        self.stream_map.remove(&id);
        res
    }

    /// Queue a RST_STREAM for `id`.
    ///
    /// Resets with INTERNAL_ERROR or NO_ERROR are not charged to the reset budget. Any
    /// other reason is, and exhausting the budget returns GOAWAY ENHANCE_YOUR_CALM
    /// without queuing the reset.
    #[cold]
    #[inline(never)]
    fn try_push_reset(&mut self, id: StreamId, reason: Reason) -> Result<(), Error> {
        if !matches!(reason, Reason::INTERNAL_ERROR | Reason::NO_ERROR) {
            self.try_tick_reset()?;
        }
        self.queue.push(Message::Reset { stream_id: id, reason });
        Ok(())
    }

    /// Process an inbound DATA frame.
    ///
    /// The flow-controlled length is always charged to the connection window first.
    /// Frames for streams not in the map have that credit handed back immediately and
    /// are answered with STREAM_CLOSED.
    fn handle_data(&mut self, data: Data) -> Result<(), Error> {
        let id = data.stream_id();
        self.check_not_idle(id)?;
        let flow_len = data.flow_controlled_len();
        self.recv_window_dec(flow_len)?;
        let stream = self.stream_map.get_mut(&id).ok_or_else(|| {
            // Unknown stream: return the connection credit before resetting it.
            self.queue.connection_window_update(flow_len);
            Error::Reset(Reason::STREAM_CLOSED)
        })?;
        let end_stream = data.is_end_stream();
        let data = data.into_payload();
        // (connection credit to return now, stream credit to return now, try removal?)
        let (conn_window, stream_window, want_remove) =
            match stream.try_recv_data(&mut self.frame_buf, data, flow_len, end_stream)? {
                RecvData::Queued(size) => {
                    // `size` as reported by `try_recv_data` is credited back at both levels.
                    (size, size, false)
                }
                RecvData::Discard(size) => {
                    // No stream-level update for a stream that is ending anyway.
                    let stream_window = if !end_stream { size } else { 0 };
                    (size, stream_window, end_stream)
                }
                RecvData::StreamReset(size) => (size, 0, true),
            };
        self.queue.connection_window_update(conn_window);
        self.queue.stream_window_update(id, stream_window);
        if want_remove {
            self.inline_remove_stream(id)?;
        }
        Ok(())
    }

    /// Mark stream `id` as reset with `StreamError::InternalError` (used when the service
    /// or response body fails).
    ///
    /// # Panics
    /// If `id` is not in the stream map.
    #[cold]
    #[inline(never)]
    fn reset(&mut self, id: &StreamId) {
        self.stream_map
            .get_mut(id)
            .expect(STREAM_MUST_EXIST)
            .try_set_reset(StreamError::InternalError);
    }

    /// Handle a connection-level error.
    ///
    /// A GOAWAY is queued when `last_stream_id.try_go_away()` yields an id. Any reason
    /// other than NO_ERROR also closes the writer queue. Returns `true` when the error
    /// was fatal (reason != NO_ERROR).
    ///
    /// # Panics
    /// If given `Error::Reset`; stream errors must be turned into RST_STREAM instead.
    #[cold]
    #[inline(never)]
    fn go_away(&mut self, err: Error) -> bool {
        let Error::GoAway(reason) = err else {
            unreachable!("Error::Reset MUST not be handled as GO_AWAY frame")
        };
        if let Some(last_stream_id) = self.last_stream_id.try_go_away() {
            self.queue.push(Message::GoAway { last_stream_id, reason });
        }
        let fatal = reason != Reason::NO_ERROR;
        if fatal {
            self.queue.close();
        }
        fatal
    }

    /// Request a keep-alive PING.
    ///
    /// # Errors
    /// Returns `TimedOut` if the previous keep-alive ping is still pending or unanswered.
    pub(super) fn try_set_pending_ping(&mut self) -> io::Result<()> {
        if self.queue.keepalive_ping != KeepalivePing::Idle {
            return Err(io::Error::new(io::ErrorKind::TimedOut, "h2 ping timeout"));
        }
        self.queue.keepalive_ping = KeepalivePing::Pending;
        Ok(())
    }

    /// Charge one reset to the budget; GOAWAY ENHANCE_YOUR_CALM when `tick` reports it
    /// exhausted.
    fn try_tick_reset(&mut self) -> Result<(), Error> {
        if self.reset_counter.tick() {
            Err(Error::GoAway(Reason::ENHANCE_YOUR_CALM))
        } else {
            Ok(())
        }
    }
}
/// Peer SETTINGS awaiting our acknowledgement.
#[derive(Default)]
struct RemoteSettings {
    // `None`: no ACK owed. `Some(x)`: a SETTINGS frame must be ACKed; `x` is its
    // SETTINGS_HEADER_TABLE_SIZE (if present), applied to the HPACK encoder on ACK.
    header_table_size: Option<Option<u32>>,
}
impl RemoteSettings {
#[cold]
#[inline(never)]
fn try_update(&mut self, settings: Settings) -> Result<(), Error> {
if self.header_table_size.is_some() {
return Err(Error::GoAway(Reason::ENHANCE_YOUR_CALM));
}
self.header_table_size = Some(settings.header_table_size());
Ok(())
}
fn encode(&mut self, encoder: &mut hpack::Encoder, buf: &mut BytesMut) {
if let Some(header_table_size) = self.header_table_size.take() {
if let Some(size) = header_table_size {
encoder.update_max_size(size as usize);
}
Settings::ack().encode(buf);
}
}
}
/// Failure raised while processing an inbound frame.
enum Error {
    /// Stream-level error, answered with RST_STREAM carrying the reason.
    Reset(Reason),
    /// Connection-level error, answered with GOAWAY carrying the reason.
    GoAway(Reason),
}

impl From<ProtoError> for Error {
    /// Protocol errors flagged as connection-level map to `GoAway`; all others to `Reset`.
    fn from(e: ProtoError) -> Self {
        match (e.is_go_away(), e.reason()) {
            (true, reason) => Error::GoAway(reason),
            (false, reason) => Error::Reset(reason),
        }
    }
}

impl From<StreamError> for Error {
    /// Stream errors are always stream-level.
    fn from(err: StreamError) -> Self {
        let reason = err.reason();
        Error::Reset(reason)
    }
}
/// Outbound frame queued for the write task (see `EncodeContext::poll_encode`).
pub(super) enum Message {
    /// Response HEADERS.
    Head(headers::Headers<ResponsePseudo>),
    /// DATA frame.
    Data(Data),
    /// Trailing HEADERS.
    Trailer(headers::Headers<()>),
    /// RST_STREAM.
    Reset { stream_id: StreamId, reason: Reason },
    /// Stream-level WINDOW_UPDATE (connection-level credit is accumulated separately in
    /// `WriterQueue::pending_conn_window`).
    WindowUpdate { stream_id: StreamId, size: usize },
    /// GOAWAY.
    GoAway { last_stream_id: StreamId, reason: Reason },
}
/// Outbound side of the connection: queued frames plus control state the encoder turns
/// into frames on each poll.
pub(super) struct WriterQueue {
    // Frames written in FIFO order.
    pub(super) messages: VecDeque<Message>,
    // Set once no further frames will be queued; lets the write task finish.
    closed: bool,
    // SETTINGS ACK owed to the peer.
    pending_settings: RemoteSettings,
    // Connection-level credit to return in the next WINDOW_UPDATE on stream 0.
    pending_conn_window: usize,
    // State of our keep-alive PING.
    keepalive_ping: KeepalivePing,
    // Payload of the latest client PING, to be echoed as a PING ACK.
    pending_client_ping: Option<[u8; 8]>,
    // Waker of the write task, registered when it found nothing to write.
    waker: Option<Waker>,
}

/// Connection state shared between the tasks of one connection (single-threaded).
pub(super) type Shared = Rc<RefCell<FlowControl>>;
impl WriterQueue {
    /// An open queue with nothing pending.
    fn new() -> Self {
        Self {
            messages: VecDeque::new(),
            pending_settings: RemoteSettings::default(),
            pending_conn_window: 0,
            pending_client_ping: None,
            keepalive_ping: KeepalivePing::Idle,
            closed: false,
            waker: None,
        }
    }

    /// Append `msg` to the back of the queue.
    fn push(&mut self, msg: Message) {
        self.messages.push_back(msg);
    }

    /// Accumulate `size` bytes of connection-level credit for the next WINDOW_UPDATE.
    pub(super) fn connection_window_update(&mut self, size: usize) {
        self.pending_conn_window += size;
    }

    /// Queue a stream-level WINDOW_UPDATE; zero increments are skipped.
    fn stream_window_update(&mut self, id: StreamId, size: usize) {
        if size == 0 {
            return;
        }
        self.push(Message::WindowUpdate { stream_id: id, size });
    }

    /// Queue a DATA frame carrying `payload`, optionally flagged END_STREAM.
    pub(super) fn push_data(&mut self, id: StreamId, payload: Bytes, end_stream: bool) {
        let mut frame = Data::new(id, payload);
        frame.set_end_stream(end_stream);
        self.push(Message::Data(frame));
    }

    /// Queue the trailing HEADERS of stream `id`.
    pub(super) fn push_trailers(&mut self, id: StreamId, trailers: HeaderMap) {
        self.push(Message::Trailer(headers::Headers::trailers(id, trailers)));
    }

    /// End stream `stream_id`: flag its most recently queued DATA frame as END_STREAM,
    /// or queue an empty END_STREAM DATA frame when none is still queued.
    fn push_end_stream(&mut self, stream_id: StreamId) {
        let last_data = self.messages.iter_mut().rev().find_map(|msg| match msg {
            Message::Data(data) if data.stream_id() == stream_id => Some(data),
            _ => None,
        });
        match last_data {
            Some(data) => data.set_end_stream(true),
            None => self.push_data(stream_id, Bytes::new(), true),
        }
    }

    /// Mark the queue closed.
    fn close(&mut self) {
        self.closed = true;
    }

    /// Pop the oldest queued frame.
    fn try_recv(&mut self) -> Option<Message> {
        self.messages.pop_front()
    }

    /// Whether `close` has been called.
    fn is_closed(&self) -> bool {
        self.closed
    }

    /// Store the waker of `cx`, cloning only when the stored one would not wake the same task.
    fn register(&mut self, cx: &mut Context<'_>) {
        let up_to_date = matches!(&self.waker, Some(registered) if registered.will_wake(cx.waker()));
        if !up_to_date {
            self.waker = Some(cx.waker().clone());
        }
    }

    /// Wake (and forget) the registered writer, if any.
    fn wake(&mut self) {
        let Some(waker) = self.waker.take() else {
            return;
        };
        waker.wake();
    }
}
/// A newly opened request together with the id of the stream it arrived on.
type Decoded = (Request<RequestExt<RequestBody>>, StreamId);
impl<'a, S> DecodeContext<'a, S> {
    /// Build a decoder with protocol-default header limits and a fresh HPACK decoder.
    ///
    /// NOTE(review): `max_header_list_size` uses the protocol default, while `run`
    /// advertises `config.h2_max_header_list_size` in SETTINGS. Confirm this is intended.
    fn new(
        ctx: &'a Shared,
        service: &'a S,
        max_concurrent_streams: usize,
        recv_threshold: StreamRecvWindowThreshold,
        addr: SocketAddr,
        date: &'a DateTimeHandle,
    ) -> Self {
        Self {
            max_header_list_size: settings::DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
            max_concurrent_streams,
            recv_threshold,
            decoder: hpack::Decoder::new(settings::DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
            next_frame_len: 0,
            continuation: None,
            ctx,
            service,
            addr,
            date,
        }
    }

    /// Split and decode complete frames from `buf` until a request is produced or more
    /// bytes are needed.
    ///
    /// Returns `Ok(Some(_))` for a newly opened request and `Ok(None)` when `buf` holds no
    /// complete frame. Connection errors come back as `Err(Error::GoAway(_))`. Stream
    /// errors are turned into RST_STREAM by `decode_frame` and do not surface here.
    fn try_decode(&mut self, buf: &mut BytesMut) -> Result<Option<Decoded>, Error> {
        loop {
            if self.next_frame_len == 0 {
                if buf.len() < 3 {
                    return Ok(None);
                }
                // Consume the 24-bit payload length. The remaining 6 header bytes
                // (type, flags, stream id) stay in `buf` and are parsed with the frame.
                let payload_len = buf.get_uint(3) as usize;
                // NOTE(review): checked against the protocol default rather than the
                // `h2_max_frame_size` that `run` advertises. Confirm this is intended.
                if payload_len > settings::DEFAULT_MAX_FRAME_SIZE as usize {
                    return Err(Error::GoAway(Reason::FRAME_SIZE_ERROR));
                }
                self.next_frame_len = payload_len + 6;
            }
            if buf.len() < self.next_frame_len {
                return Ok(None);
            }
            let len = mem::replace(&mut self.next_frame_len, 0);
            let mut frame = buf.split_to(len);
            let head = head::Head::parse(&frame);
            frame.advance(6);
            if let Some(decoded) = self.decode_frame(head, frame)? {
                return Ok(Some(decoded));
            }
        }
    }

    /// Decode one frame, answering stream-level errors with a queued RST_STREAM. Queuing
    /// the reset can itself escalate to GOAWAY when the reset budget is exhausted.
    fn decode_frame(&mut self, head: head::Head, frame: BytesMut) -> Result<Option<Decoded>, Error> {
        match self._decode_frame(head, frame) {
            Err(Error::Reset(reason)) => {
                self.ctx.borrow_mut().try_push_reset(head.stream_id(), reason)?;
                Ok(None)
            }
            res => res,
        }
    }

    /// Dispatch one frame by type. While a header block awaits CONTINUATION, any other
    /// frame type is a connection PROTOCOL_ERROR.
    fn _decode_frame(&mut self, head: head::Head, frame: BytesMut) -> Result<Option<Decoded>, Error> {
        match (head.kind(), &self.continuation) {
            (head::Kind::Continuation, _) => return self.handle_continuation(head, frame),
            (_, Some(_)) => return Err(Error::GoAway(Reason::PROTOCOL_ERROR)),
            (head::Kind::Headers, _) => {
                let (headers, payload) = headers::Headers::load(head, frame)?;
                let is_end_headers = headers.is_end_headers();
                return self.handle_headers(headers, payload, is_end_headers);
            }
            (head::Kind::Data, _) => {
                let data = Data::load(head, frame.freeze())?;
                self.ctx.borrow_mut().handle_data(data)?;
            }
            (head::Kind::WindowUpdate, _) => {
                let window = WindowUpdate::load(head, frame.as_ref())?;
                let flow = &mut self.ctx.borrow_mut();
                let id = window.stream_id();
                flow.check_not_idle(id)?;
                match (window.size_increment(), id) {
                    // Zero increment: connection error on stream 0, stream error otherwise.
                    (0, StreamId::ZERO) => return Err(Error::GoAway(Reason::PROTOCOL_ERROR)),
                    (0, id) => {
                        if let Some(state) = flow.stream_map.get_mut(&id) {
                            state.try_set_reset(StreamError::WindowUpdateZeroIncrement);
                        }
                    }
                    (incr, StreamId::ZERO) => {
                        let incr = incr as usize;
                        if flow.send_connection_window + incr > settings::MAX_INITIAL_WINDOW_SIZE {
                            return Err(Error::GoAway(Reason::FLOW_CONTROL_ERROR));
                        }
                        let was_zero = flow.send_connection_window == 0;
                        flow.send_connection_window += incr;
                        // Streams are woken only on the 0 -> positive transition,
                        // presumably the only case in which they were blocked on the
                        // connection window.
                        if was_zero {
                            flow.update_and_wake_send_streams(0);
                        }
                    }
                    (incr, id) => {
                        let incr = incr as i64;
                        let window = flow.send_connection_window;
                        // Updates for streams no longer in the map are ignored.
                        if let Some(state) = flow.stream_map.get_mut(&id) {
                            if state.send.window + incr > settings::MAX_INITIAL_WINDOW_SIZE as i64 {
                                state.try_set_reset(StreamError::WindowUpdateOverflow);
                            } else {
                                state.send.window += incr;
                                // Skip the wake while the connection window is empty.
                                if window > 0 {
                                    state.send.wake();
                                }
                            }
                        }
                    }
                }
            }
            (head::Kind::Ping, _) => {
                let ping = Ping::load(head, frame.as_ref())?;
                if ping.is_ack {
                    // Our keep-alive ping was answered.
                    self.ctx.borrow_mut().queue.keepalive_ping = KeepalivePing::Idle;
                } else {
                    // Echoed back as PING ACK by the encoder; only the latest payload is kept.
                    self.ctx.borrow_mut().queue.pending_client_ping = Some(ping.payload);
                }
            }
            (head::Kind::Reset, _) => {
                let reset = Reset::load(head, frame.as_ref())?;
                let id = reset.stream_id();
                if id.is_zero() {
                    return Err(Error::GoAway(Reason::PROTOCOL_ERROR));
                }
                let mut inner = self.ctx.borrow_mut();
                inner.check_not_idle(id)?;
                inner.try_reset_stream(id)?;
            }
            (head::Kind::GoAway, _) => {
                // The peer is going away: answer with GOAWAY carrying the same reason.
                let go_away = GoAway::load(head.stream_id(), frame.as_ref())?;
                return Err(Error::GoAway(go_away.reason()));
            }
            (head::Kind::Priority, _) => handle_priority(head.stream_id(), &frame)?,
            // A client must never send PUSH_PROMISE.
            (head::Kind::PushPromise, _) => return Err(Error::GoAway(Reason::PROTOCOL_ERROR)),
            (head::Kind::Settings, _) => self.handle_settings(head, &frame)?,
            // Unknown frame types are ignored.
            (head::Kind::Unknown, _) => {}
        }
        Ok(None)
    }

    /// Append a CONTINUATION fragment to the pending header block and retry decoding.
    ///
    /// A CONTINUATION with no pending block, or one for a different stream, is a
    /// connection PROTOCOL_ERROR.
    #[cold]
    #[inline(never)]
    fn handle_continuation(&mut self, head: head::Head, frame: BytesMut) -> Result<Option<Decoded>, Error> {
        // 0x4 is the END_HEADERS flag.
        let is_end_headers = (head.flag() & 0x4) == 0x4;
        let (headers, mut payload) = self.continuation.take().ok_or(Error::GoAway(Reason::PROTOCOL_ERROR))?;
        if headers.stream_id() != head.stream_id() {
            return Err(Error::GoAway(Reason::PROTOCOL_ERROR));
        }
        payload.unsplit(frame);
        self.handle_headers(headers, payload, is_end_headers)
    }

    /// Apply a peer SETTINGS frame; ACK frames are ignored.
    ///
    /// - INITIAL_WINDOW_SIZE shifts every open stream's send window by the delta. Growing
    ///   any window past the maximum is a connection FLOW_CONTROL_ERROR.
    /// - MAX_FRAME_SIZE is applied to the connection and to every open stream.
    ///
    /// Finally the frame is recorded so the encoder emits its ACK.
    #[cold]
    #[inline(never)]
    fn handle_settings(&mut self, head: head::Head, frame: &BytesMut) -> Result<(), Error> {
        let setting = Settings::load(head, frame)?;
        if setting.is_ack() {
            return Ok(());
        }
        let mut flow = self.ctx.borrow_mut();
        if let Some(new_window) = setting.initial_window_size() {
            let new_window = new_window as i64;
            let delta = new_window - flow.send_stream_initial_window;
            flow.send_stream_initial_window = new_window;
            if delta > 0 {
                let overflow = flow
                    .stream_map
                    .values()
                    .any(|s| s.send.window + delta > settings::MAX_INITIAL_WINDOW_SIZE as i64);
                if overflow {
                    return Err(Error::GoAway(Reason::FLOW_CONTROL_ERROR));
                }
            }
            if delta != 0 {
                flow.update_and_wake_send_streams(delta);
            }
        }
        if let Some(frame_size) = setting.max_frame_size() {
            let frame_size = frame_size as usize;
            flow.max_frame_size = frame_size;
            for state in flow.stream_map.values_mut() {
                state.send.frame_size = frame_size;
            }
        }
        flow.queue.pending_settings.try_update(setting)
    }

    /// HPACK-decode a header block.
    ///
    /// - Parking: a block without END_HEADERS is parked in `self.continuation`. This
    ///   includes one that needs more bytes (`NeedMore`).
    /// - Malformed message: it is answered with a stream PROTOCOL_ERROR unless
    ///   `last_stream_id.try_set` returns `None`, in which case it is ignored.
    /// - Any other decode failure: a connection COMPRESSION_ERROR.
    fn handle_headers(
        &mut self,
        mut headers: headers::Headers,
        mut payload: BytesMut,
        is_end_headers: bool,
    ) -> Result<Option<Decoded>, Error> {
        if let Err(e) = headers.load_hpack(&mut payload, self.max_header_list_size, &mut self.decoder) {
            return match e {
                ProtoError::Hpack(hpack::DecoderError::NeedMore(_)) if !is_end_headers => {
                    self.continuation = Some((headers, payload));
                    Ok(None)
                }
                ProtoError::MalformedMessage => {
                    let id = headers.stream_id();
                    if self.ctx.borrow_mut().last_stream_id.try_set(id)?.is_none() {
                        return Ok(None);
                    }
                    Err(Error::Reset(Reason::PROTOCOL_ERROR))
                }
                _ => Err(Error::GoAway(Reason::COMPRESSION_ERROR)),
            };
        }
        if !is_end_headers {
            self.continuation = Some((headers, payload));
            return Ok(None);
        }
        let id = headers.stream_id();
        self.handle_header_frame(id, headers)
    }

    /// Turn a complete header block into trailers for an existing stream or into a new
    /// request.
    ///
    /// A new stream is refused with REFUSED_STREAM once `max_concurrent_streams` are open.
    /// Its pseudo-headers are validated, with violations becoming stream PROTOCOL_ERRORs:
    /// - `:method` is required.
    /// - A plain CONNECT (no `:protocol`) must omit `:scheme` and `:path`.
    /// - Every other request needs `:scheme` and a non-empty `:path`.
    ///
    /// `:authority`, `:scheme` and `:path` become part of the URI only when they parse;
    /// `:scheme` additionally requires an authority.
    fn handle_header_frame(&mut self, id: StreamId, headers: headers::Headers) -> Result<Option<Decoded>, Error> {
        let end_stream = headers.is_end_stream();
        let (pseudo, headers) = headers.into_parts();
        let flow = &mut *self.ctx.borrow_mut();
        // A non-idle id names an already known stream, so this block is treated as trailers.
        if !flow.last_stream_id.check_idle(id) {
            let stream = flow
                .stream_map
                .get_mut(&id)
                .ok_or(Error::GoAway(Reason::STREAM_CLOSED))?;
            match stream.try_recv_trailers(&mut flow.frame_buf, headers, end_stream)? {
                RecvData::Queued(_) => {}
                _ => flow.inline_remove_stream(id)?,
            }
            return Ok(None);
        }
        if flow.last_stream_id.try_set(id)?.is_none() {
            return Ok(None);
        }
        if flow.stream_map.len() >= self.max_concurrent_streams {
            return Err(Error::Reset(Reason::REFUSED_STREAM));
        }
        let content_length =
            BodySize::from_header(&headers, end_stream).map_err(|_| Error::Reset(Reason::PROTOCOL_ERROR))?;
        let method = pseudo.method.ok_or(Error::Reset(Reason::PROTOCOL_ERROR))?;
        let protocol = pseudo.protocol.map(|proto| Protocol::from_str(&proto));
        // Extended CONNECT (with `:protocol`) follows the regular request rules.
        let is_strict_connect = method == Method::CONNECT && protocol.is_none();
        let mut uri_parts = uri::Parts::default();
        if let Some(authority) = pseudo.authority {
            if let Ok(a) = uri::Authority::from_maybe_shared(authority.into_inner()) {
                uri_parts.authority = Some(a);
            }
        }
        match (is_strict_connect, pseudo.scheme) {
            (true, Some(_)) | (false, None) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)),
            (false, Some(scheme)) if uri_parts.authority.is_some() => {
                if let Ok(s) = uri::Scheme::try_from(scheme.as_str()) {
                    uri_parts.scheme = Some(s);
                }
            }
            _ => {}
        }
        match (is_strict_connect, pseudo.path) {
            (true, Some(_)) | (false, None) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)),
            (_, Some(path)) if !path.is_empty() => {
                if let Ok(pq) = uri::PathAndQuery::from_maybe_shared(path.into_inner()) {
                    uri_parts.path_and_query = Some(pq);
                }
            }
            // Empty `:path` on a non-CONNECT request.
            (false, _) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)),
            _ => {}
        }
        let ext = Extension::with_protocol(self.addr, protocol);
        let mut req = Request::new(RequestExt::from_parts((), ext));
        *req.version_mut() = Version::HTTP_2;
        *req.headers_mut() = headers;
        *req.method_mut() = method;
        if let Ok(uri) = Uri::from_parts(uri_parts) {
            *req.uri_mut() = uri;
        }
        let body = RequestBody::new(id, content_length, Rc::clone(self.ctx), self.recv_threshold);
        flow.insert_stream(id, end_stream, content_length);
        let req = req.map(|ext| ext.map_body(|_| body));
        Ok(Some((req, id)))
    }
}
/// Read more bytes from `io` after the current contents of `buf`, returning the read
/// result together with the (grown) buffer.
///
/// Once `buf` holds `LIMIT` or more bytes the returned future never completes. Reading
/// stops until the caller replaces the future (back-pressure).
async fn read_io<const LIMIT: usize>(mut buf: BytesMut, io: &impl AsyncBufRead) -> (io::Result<usize>, BytesMut) {
    let filled = buf.len();
    if filled >= LIMIT {
        return core::future::pending().await;
    }
    buf.reserve(4096);
    let (res, slice) = io.read(buf.slice(filled..)).await;
    (res, slice.into_inner())
}
/// Write all of `buf` to `io`; the buffer comes back cleared (ready for reuse) along with
/// the write result.
async fn write_io(buf: BytesMut, io: &impl AsyncBufWrite) -> (io::Result<()>, BytesMut) {
    let (outcome, mut spent) = write_all(io, buf).await;
    spent.clear();
    (outcome, spent)
}
/// Write-side state: the HPACK encoder plus access to the shared outbound queue.
struct EncodeContext<'a> {
    // HPACK encoder for response headers and trailers.
    encoder: hpack::Encoder,
    // Connection state holding the queue to drain.
    ctx: &'a Shared,
}
impl<'a> EncodeContext<'a> {
    /// Create an encoder over the shared connection state.
    fn new(ctx: &'a Shared) -> Self {
        Self {
            encoder: hpack::Encoder::new(65535, 4096),
            ctx,
        }
    }

    /// Serialize everything currently pending into `write_buf`.
    ///
    /// Output order:
    /// 1. A SETTINGS ACK, if owed.
    /// 2. Every queued message.
    /// 3. A connection WINDOW_UPDATE for accumulated credit, which is also added back to
    ///    `recv_connection_window`.
    /// 4. A PING ACK echoing the last client ping.
    /// 5. A pending keep-alive PING, which then becomes `InFlight`.
    ///
    /// Returns `Ready(true)` when `write_buf` has bytes to write and `Ready(false)` when the
    /// queue is closed with nothing left. Otherwise it registers `cx` and returns `Pending`.
    fn poll_encode(&mut self, write_buf: &mut BytesMut, cx: &mut Context<'_>) -> Poll<bool> {
        let mut flow = self.ctx.borrow_mut();
        flow.queue.pending_settings.encode(&mut self.encoder, write_buf);
        while let Some(msg) = flow.queue.try_recv() {
            match msg {
                Message::Head(headers) => {
                    // Each frame is capped at the peer's max frame size; `encode` returns
                    // a continuation while part of the header block remains.
                    let frame_size = flow.max_frame_size;
                    let mut cont = headers.encode(&mut self.encoder, &mut write_buf.limit(frame_size));
                    while let Some(c) = cont {
                        cont = c.encode(&mut write_buf.limit(frame_size));
                    }
                }
                Message::Trailer(headers) => {
                    let frame_size = flow.max_frame_size;
                    let mut cont = headers.encode(&mut self.encoder, &mut write_buf.limit(frame_size));
                    while let Some(c) = cont {
                        cont = c.encode(&mut write_buf.limit(frame_size));
                    }
                }
                Message::Data(mut data) => data.encode_chunk(write_buf),
                Message::Reset { stream_id, reason } => Reset::new(stream_id, reason).encode(write_buf),
                Message::WindowUpdate { stream_id, size } => WindowUpdate::new(stream_id, size as _).encode(write_buf),
                Message::GoAway { last_stream_id, reason } => {
                    GoAway::new(last_stream_id, reason).encode(write_buf);
                }
            }
        }
        let pending = mem::replace(&mut flow.queue.pending_conn_window, 0);
        if pending > 0 {
            flow.recv_connection_window += pending;
            WindowUpdate::new(StreamId::zero(), pending as _).encode(write_buf);
        }
        if let Some(payload) = flow.queue.pending_client_ping.take() {
            // Flag 0x1 = ACK, 8-byte opaque payload echoed back.
            head::Head::new(head::Kind::Ping, 0x1, StreamId::zero()).encode(8, write_buf);
            write_buf.put_slice(&payload);
        }
        if flow.queue.keepalive_ping == KeepalivePing::Pending {
            head::Head::new(head::Kind::Ping, 0x0, StreamId::zero()).encode(8, write_buf);
            write_buf.put_slice(&[0u8; 8]);
            flow.queue.keepalive_ping = KeepalivePing::InFlight;
        }
        if !write_buf.is_empty() {
            Poll::Ready(true)
        } else if flow.queue.is_closed() {
            Poll::Ready(false)
        } else {
            flow.queue.register(cx);
            Poll::Pending
        }
    }
}
/// Serve one request to completion and release its stream.
///
/// When the service call or the response body fails, the stream is marked reset with
/// `StreamError::InternalError`. In every case the send side is then finished through
/// `response_task_done`, and this returns `Err(())` only when that had to fall back to
/// GOAWAY.
async fn response_task<S, ReqB, ResB, ResBE>(
    req: Request<RequestExt<RequestBody>>,
    stream_id: StreamId,
    service: &S,
    ctx: &Shared,
    date: &DateTimeHandle,
) -> Result<(), ()>
where
    S: Service<Request<RequestExt<ReqB>>, Response = Response<ResB>>,
    S::Error: fmt::Debug,
    ReqB: From<RequestBody>,
    ResB: Body<Data = Bytes, Error = ResBE>,
    ResBE: fmt::Debug,
{
    let mut flow = match _response_task(req, stream_id, service, ctx, date).await {
        Ok(flow) => flow,
        Err(()) => {
            let mut flow = ctx.borrow_mut();
            flow.reset(&stream_id);
            flow
        }
    };
    flow.response_task_done(stream_id)
}
/// Call the service for one request and queue the response frames.
///
/// Queues HEADERS first. They carry END_STREAM for HEAD requests and for bodies whose
/// size hint is `None`; in that case the body is never polled. Otherwise body frames are
/// streamed subject to flow control, ending with either an END_STREAM DATA frame or
/// trailers.
///
/// On success the still-borrowed `FlowControl` is returned, so the caller can finish the
/// stream without re-borrowing. `Err(())` means the service or the body failed.
#[allow(clippy::await_holding_refcell_ref)]
async fn _response_task<'a, S, ReqB, ResB, ResBE>(
    req: Request<RequestExt<RequestBody>>,
    stream_id: StreamId,
    service: &S,
    ctx: &'a Shared,
    date: &DateTimeHandle,
) -> Result<RefMut<'a, FlowControl>, ()>
where
    S: Service<Request<RequestExt<ReqB>>, Response = Response<ResB>>,
    S::Error: fmt::Debug,
    ReqB: From<RequestBody>,
    ResB: Body<Data = Bytes, Error = ResBE>,
    ResBE: fmt::Debug,
{
    let req = req.map(|ext| ext.map_body(From::from));
    let head_method = req.method() == Method::HEAD;
    let res = service.call(req).await.map_err(|_| ())?;
    let (mut parts, body) = res.into_parts();
    // Remove connection-specific headers before sending over HTTP/2.
    super::strip_connection_headers::<false>(&mut parts.headers);
    if !parts.headers.contains_key(DATE) {
        let date = date.with_date_header(Clone::clone);
        parts.headers.insert(DATE, date);
    }
    // Decide whether HEADERS alone ends the stream. Exact body sizes also provide a
    // default content-length.
    let end_stream = match (head_method, body.size_hint()) {
        (true, _) => true,
        (false, SizeHint::None) => true,
        (false, size) => {
            if let SizeHint::Exact(size) = size {
                parts.headers.entry(CONTENT_LENGTH).or_insert_with(|| size.into());
            }
            false
        }
    };
    let pseudo = headers::Pseudo::response(parts.status);
    let mut headers = headers::Headers::new(stream_id, pseudo, parts.headers);
    if end_stream {
        headers.set_end_stream();
    }
    let mut flow = ctx.borrow_mut();
    flow.queue.push(Message::Head(headers));
    if !end_stream {
        // Release the borrow while polling the body; it is re-acquired per frame.
        drop(flow);
        let mut body = pin!(body);
        flow = loop {
            match poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
                None => {
                    // Body exhausted without trailers: end the stream on the last DATA.
                    let mut flow = ctx.borrow_mut();
                    flow.queue.push_end_stream(stream_id);
                    break flow;
                }
                Some(Err(e)) => {
                    error!("body error: {e:?}");
                    return Err(());
                }
                Some(Ok(Frame::Data(bytes))) => {
                    // `Some` means this stream is done sending (END_STREAM queued or the
                    // stream can no longer send); `None` means keep polling the body.
                    if let Some(flow) = send_data(stream_id, bytes, body.is_end_stream(), ctx).await {
                        break flow;
                    }
                }
                Some(Ok(Frame::Trailers(trailers))) => {
                    let mut flow = ctx.borrow_mut();
                    flow.queue.push_trailers(stream_id, trailers);
                    break flow;
                }
            }
        }
    }
    Ok(flow)
}
/// Queue one body chunk for `stream_id` as flow control allows.
///
/// An empty chunk that is not the last one is skipped with a warning. Returns `Some` with
/// the borrowed state once the stream has finished sending: END_STREAM was queued, or the
/// stream can no longer send. Returns `None` when the chunk was fully queued and more
/// body frames should follow.
async fn send_data(
    stream_id: StreamId,
    data: Bytes,
    end_stream: bool,
    ctx: &Shared,
) -> Option<RefMut<'_, FlowControl>> {
    if !end_stream && data.is_empty() {
        tracing::warn!("response body should not yield empty Frame::Data unless it's the last chunk of Body");
        return None;
    }
    let pending_send = SendData {
        stream_id,
        end_stream,
        data,
        flow: ctx,
    };
    pending_send.await
}
/// Future that queues `data` for `stream_id` as DATA frames, waiting for the stream and
/// connection send windows as needed.
struct SendData<'a> {
    // Bytes not queued yet; shrinks as window becomes available.
    data: Bytes,
    // Whether the final DATA frame of this chunk carries END_STREAM.
    end_stream: bool,
    stream_id: StreamId,
    flow: &'a Shared,
}
impl<'a> Future for SendData<'a> {
    type Output = Option<RefMut<'a, FlowControl>>;

    /// Repeatedly ask the stream for send capacity and queue as much data as granted.
    ///
    /// Resolves to:
    /// - `Some(flow)` when the stream cannot send anymore, or once the final piece was
    ///   queued with END_STREAM;
    /// - `None` when all data was queued without ending the stream.
    ///
    /// Stays pending (with the borrow released) while no window is available.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let me = self.get_mut();
        let mut guard = me.flow.borrow_mut();
        loop {
            let remaining = me.data.len();
            let conn_window = guard.send_connection_window;
            let offered = ready!(guard
                .stream_map
                .get_mut(&me.stream_id)
                .expect(STREAM_MUST_EXIST)
                .poll_send_window(remaining, conn_window, cx));
            let granted = match offered {
                Some(Ok(granted)) => granted,
                _ => return Poll::Ready(Some(guard)),
            };
            guard.send_connection_window -= granted;

            let last_piece = granted == remaining;
            let chunk = if last_piece {
                mem::take(&mut me.data)
            } else {
                me.data.split_to(granted)
            };
            let finish = last_piece && me.end_stream;
            guard.queue.push_data(me.stream_id, chunk, finish);

            match (finish, last_piece) {
                (true, _) => return Poll::Ready(Some(guard)),
                (false, true) => return Poll::Ready(None),
                (false, false) => {}
            }
        }
    }
}
// Reset budget handed to `ResetCounter::new`. Once `tick` reports it exhausted, the
// connection is closed with GOAWAY ENHANCE_YOUR_CALM. Presumably this means at most
// `RESET_MAX` counted resets per `RESET_WINDOW`; confirm against `ResetCounter`.
const RESET_MAX: usize = 20;
const RESET_WINDOW: Duration = Duration::from_secs(30);
/// Detect HTTP/2 with prior knowledge by checking for the client connection preface.
///
/// Never returns an error. Anything other than a verified preface, read failures
/// included, is reported as HTTP/1.1. The returned buffer holds every byte read so far
/// (the preface is not consumed), so the chosen protocol handler can continue from it.
pub(crate) async fn peek_version(
    io: &(impl AsyncBufRead + AsyncBufWrite),
    buf: BytesMut,
) -> io::Result<(Version, BytesMut)> {
    let (read_buf, preface) = prefix_check::<4096>(buf, io).await;
    let version = match preface {
        Ok(()) => Version::HTTP_2,
        Err(_) => Version::HTTP_11,
    };
    Ok((version, read_buf))
}
/// Drive one HTTP/2 server connection to completion.
///
/// Phases:
/// 1. Handshake: validate the client preface and send our SETTINGS under the keep-alive
///    timer.
/// 2. Main loop: race reading/decoding frames, the in-flight response tasks, the write
///    task and the keep-alive ping until one of them ends the connection (`ShutDown`).
/// 3. Shutdown: reset all open streams, let response tasks finish and (after a read-side
///    stop) flush the remaining frames.
/// 4. Linger-read, then shut down the write half.
///
/// A keep-alive/ping timeout returns its error directly from the shutdown phase. An error
/// from `lingering_read` takes precedence over the result of the earlier phases.
pub(crate) async fn run<
    Io,
    S,
    ReqB,
    ResB,
    ResBE,
    const HEADER_LIMIT: usize,
    const READ_BUF_LIMIT: usize,
    const WRITE_BUF_LIMIT: usize,
>(
    io: Io,
    addr: SocketAddr,
    read_buf: BytesMut,
    mut ka: Pin<&mut KeepAlive>,
    service: &S,
    date: &DateTimeHandle,
    config: &HttpServiceConfig<HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>,
) -> io::Result<()>
where
    Io: AsyncBufRead + AsyncBufWrite,
    S: Service<Request<RequestExt<ReqB>>, Response = Response<ResB>>,
    ReqB: From<RequestBody>,
    S::Error: fmt::Debug,
    ResB: Body<Data = Bytes, Error = ResBE>,
    ResBE: fmt::Debug,
{
    // Our SETTINGS, sent during the handshake.
    let mut settings = settings::Settings::default();
    settings.set_max_concurrent_streams(Some(config.h2_max_concurrent_streams));
    settings.set_initial_window_size(Some(config.h2_initial_window_size));
    settings.set_max_frame_size(Some(config.h2_max_frame_size));
    settings.set_max_header_list_size(Some(config.h2_max_header_list_size));
    settings.set_enable_connect_protocol(Some(1));
    let max_concurrent_streams = config.h2_max_concurrent_streams as usize;
    // Send-side values start at protocol defaults until the peer's SETTINGS arrive; the
    // per-stream receive window uses our configured size.
    let mut flow = FlowControl {
        send_connection_window: settings::DEFAULT_INITIAL_WINDOW_SIZE as usize,
        send_stream_initial_window: settings::DEFAULT_INITIAL_WINDOW_SIZE as i64,
        max_frame_size: settings::DEFAULT_MAX_FRAME_SIZE as usize,
        recv_stream_initial_window: config.h2_initial_window_size as usize,
        recv_connection_window: settings::DEFAULT_INITIAL_WINDOW_SIZE as usize,
        stream_map: HashMap::with_capacity_and_hasher(max_concurrent_streams, NoHashBuilder::default()),
        reset_counter: ResetCounter::new(RESET_MAX, RESET_WINDOW),
        last_stream_id: LastStreamId::new(),
        queue: WriterQueue::new(),
        frame_buf: FrameBuffer::new(),
    };
    let (mut read_buf, mut write_buf) = flow
        .handshake::<READ_BUF_LIMIT>(&io, read_buf, &settings, ka.as_mut())
        .await?;
    let shared = Rc::new(RefCell::new(flow));
    let recv_threshold = StreamRecvWindowThreshold::from(&settings);
    let mut ctx = DecodeContext::new(&shared, service, max_concurrent_streams, recv_threshold, addr, date);
    let mut enc = EncodeContext::new(&shared);
    // In-flight response tasks, one per decoded request.
    let mut queue = Queue::new();
    let mut ping_pong = PingPong::new(ka.as_mut(), &shared, date, config.keep_alive_timeout);
    let res = {
        let mut read_task = pin!(read_io::<READ_BUF_LIMIT>(read_buf, &io));
        // Write task: encode whatever is pending, write it, repeat until the queue closes.
        let mut write_task = pin!(async {
            while poll_fn(|cx| enc.poll_encode(&mut write_buf, cx)).await {
                let (res, buf) = io.write(write_buf).await;
                write_buf = buf;
                match res {
                    Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
                    Ok(n) => write_buf.advance(n),
                    Err(e) => return Err(e),
                }
            }
            Ok(())
        });
        let shutdown = 'body: loop {
            match read_task
                .as_mut()
                .select(async {
                    // Drive response tasks; this branch only completes once a task
                    // reports `Err(())`.
                    loop {
                        let res: Result<(), ()> = queue.next().await;
                        if res.is_err() {
                            break;
                        }
                    }
                })
                .select(write_task.as_mut())
                .select(ping_pong.tick())
                .await
            {
                SelectOutput::A(SelectOutput::A(SelectOutput::A((res, buf)))) => {
                    read_buf = buf;
                    match res {
                        // Decode every complete frame now in the buffer.
                        Ok(n) if n > 0 => loop {
                            match ctx.try_decode(&mut read_buf) {
                                Ok(Some((req, id))) => {
                                    queue.push(response_task(req, id, ctx.service, ctx.ctx, ctx.date));
                                }
                                Ok(None) => break,
                                Err(err) => {
                                    // A fatal GOAWAY stops reading but still flushes queued frames.
                                    if ctx.ctx.borrow_mut().go_away(err) {
                                        break 'body ShutDown::ReadClosed(Ok(()));
                                    }
                                }
                            }
                        },
                        // EOF (Ok(0)) or read error.
                        res => break ShutDown::ReadClosed(res.map(|_| ())),
                    };
                    read_task.set(read_io(read_buf, &io));
                }
                SelectOutput::A(SelectOutput::A(SelectOutput::B(_))) => break ShutDown::ReadClosed(Ok(())),
                SelectOutput::A(SelectOutput::B(res)) => break ShutDown::WriteClosed(res),
                SelectOutput::B(Err(e)) => break ShutDown::Timeout(e),
                SelectOutput::B(Ok(_)) => {}
            }
        };
        Box::pin(async {
            let (io_res, want_write) = match shutdown {
                ShutDown::WriteClosed(res) => (res, false),
                ShutDown::Timeout(err) => return Err(err),
                ShutDown::ReadClosed(res) => (res, true),
            };
            // Mark every open stream reset: GoAway after a clean stop, Io after a failure.
            let stream_err = if io_res.is_ok() {
                StreamError::GoAway
            } else {
                StreamError::Io
            };
            for stream in ctx.ctx.borrow_mut().stream_map.values_mut() {
                stream.try_set_reset(stream_err);
            }
            loop {
                // With no response task left, close the writer queue so the write task
                // drains what is left and returns.
                if queue.is_empty() {
                    shared.borrow_mut().queue.close();
                    if !want_write {
                        break io_res;
                    }
                }
                match queue
                    .next()
                    .select(async {
                        if want_write {
                            write_task.as_mut().await
                        } else {
                            core::future::pending().await
                        }
                    })
                    .select(ping_pong.tick())
                    .await
                {
                    SelectOutput::A(SelectOutput::A(_)) => {
                        // A response task finished; loop to re-check the queue.
                    }
                    SelectOutput::A(SelectOutput::B(res)) => {
                        res?;
                        break io_res;
                    }
                    SelectOutput::B(res) => res?,
                }
            }
        })
        .await
    };
    lingering_read(&io, ka, date).await?;
    let _ = io.shutdown(Shutdown::Write).await;
    res
}
/// Why the main loop of `run` stopped.
enum ShutDown {
    /// Reading stopped because of EOF, a read error, a fatal GOAWAY, or a failed response
    /// task. Queued frames are still flushed.
    ReadClosed(io::Result<()>),
    /// The write task ended; nothing more can be written.
    WriteClosed(io::Result<()>),
    /// `PingPong::tick` failed (keep-alive/ping timeout).
    Timeout(io::Error),
}
#[cold]
#[inline(never)]
fn handle_priority(id: StreamId, payload: &[u8]) -> Result<(), Error> {
if id.is_zero() {
Err(Error::GoAway(Reason::PROTOCOL_ERROR))
} else if payload.len() != 5 {
Err(Error::Reset(Reason::FRAME_SIZE_ERROR))
} else if id == StreamId::parse(&payload[..4]).0 {
Err(Error::Reset(Reason::PROTOCOL_ERROR))
} else {
Ok(())
}
}
/// Heap-allocated future borrowing for `'a`; not required to be `Send`.
type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
impl FlowControl {
    /// Perform the server side of the HTTP/2 handshake, bounded by `timer`.
    ///
    /// The handshake:
    /// 1. Reads and validates the client preface.
    /// 2. Writes `settings`.
    /// 3. When our stream receive window exceeds the protocol default, sends a
    ///    connection-level WINDOW_UPDATE raising the connection receive window by the
    ///    same amount.
    ///
    /// Returns the read buffer with the preface consumed, and the emptied write buffer.
    /// Fails with `TimedOut` when the timer fires first.
    #[cold]
    #[inline(never)]
    fn handshake<'a, const LIMIT: usize>(
        &'a mut self,
        io: &'a (impl AsyncBufRead + AsyncBufWrite),
        buf: BytesMut,
        settings: &'a settings::Settings,
        timer: Pin<&'a mut KeepAlive>,
    ) -> BoxedFuture<'a, io::Result<(BytesMut, BytesMut)>> {
        Box::pin(async move {
            let exchange = async {
                let (mut read_buf, preface) = prefix_check::<LIMIT>(buf, io).await;
                preface?;
                read_buf.advance(PREFACE.len());

                let mut write_buf = BytesMut::new();
                settings.encode(&mut write_buf);

                let extra = (self.recv_stream_initial_window as u32)
                    .saturating_sub(settings::DEFAULT_INITIAL_WINDOW_SIZE);
                if extra != 0 {
                    WindowUpdate::new(StreamId::ZERO, extra).encode(&mut write_buf);
                    self.recv_connection_window += extra as usize;
                }

                let (written, write_buf) = write_io(write_buf, io).await;
                written.map(|()| (read_buf, write_buf))
            };

            match exchange.timeout(timer).await {
                Ok(outcome) => outcome,
                Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "h2 handshake timeout")),
            }
        })
    }
}
/// Read from `io` until `read_buf` holds the HTTP/2 client connection preface
/// ([`PREFACE`]) and verify it.
///
/// Returns the buffer holding everything read so far (nothing is consumed) together
/// with one of:
/// - `Ok(())` when the buffer starts with the preface;
/// - `InvalidData` as soon as the bytes received so far diverge from the preface. It does
///   not wait for the rest, so a short HTTP/1 request is classified immediately instead
///   of stalling;
/// - `UnexpectedEof` when the peer closes before the preface is complete;
/// - any error raised by the read itself.
async fn prefix_check<const LIMIT: usize>(
    mut read_buf: BytesMut,
    io: &(impl AsyncBufRead + AsyncBufWrite),
) -> (BytesMut, io::Result<()>) {
    loop {
        let received = read_buf.len().min(PREFACE.len());
        // Fail fast on the first byte that cannot belong to the preface.
        if !PREFACE.starts_with(&read_buf[..received]) {
            let err = io::Error::new(io::ErrorKind::InvalidData, "invalid HTTP/2 client preface");
            return (read_buf, Err(err));
        }
        if received == PREFACE.len() {
            return (read_buf, Ok(()));
        }
        let (res, buf) = read_io::<LIMIT>(read_buf, io).await;
        read_buf = buf;
        match res {
            // A closed socket keeps reporting 0 bytes; retrying would spin forever.
            Ok(0) => return (read_buf, Err(io::ErrorKind::UnexpectedEof.into())),
            Ok(_) => {}
            Err(e) => return (read_buf, Err(e)),
        }
    }
}
/// Read and discard whatever the client still sends, for at most 5 seconds, before the
/// write side is shut down.
///
/// Returns `Ok(())` on EOF or once the deadline passes; read errors are propagated.
/// NOTE(review): presumably a "lingering close", so unread client bytes do not make
/// the peer's stack reset the connection before our final frames arrive. Confirm.
#[cold]
#[inline(never)]
async fn lingering_read(io: &impl AsyncBufRead, mut ka: Pin<&mut KeepAlive>, date: &DateTimeHandle) -> io::Result<()> {
    let deadline = date.now() + Duration::from_secs(5);
    ka.as_mut().update(deadline);
    ka.as_mut().reset();

    let mut scratch = BytesMut::with_capacity(4096);
    loop {
        scratch.clear();
        let Ok((res, buf)) = io.read(scratch).timeout(ka.as_mut()).await else {
            // Deadline reached.
            return Ok(());
        };
        scratch = buf;
        if res? == 0 {
            return Ok(());
        }
    }
}