use core::{
cell::RefCell,
fmt,
future::poll_fn,
mem,
pin::{Pin, pin},
task::Poll,
};
use std::{
collections::{HashMap, VecDeque},
io,
net::{Shutdown, SocketAddr},
rc::Rc,
time::Duration,
};
use tracing::error;
use xitca_io::{
bytes::{Buf, BufMut, BytesMut},
io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, write_all},
};
use xitca_service::Service;
use xitca_unsafe_collection::{
futures::{Select, SelectOutput},
no_hash::NoHashBuilder,
};
use crate::{
body::{Body, SizeHint},
bytes::Bytes,
config::HttpServiceConfig,
date::{DateTime, DateTimeHandle},
http::{
Extension, HeaderMap, Method, Protocol, Request, RequestExt, Response, Uri, Version,
header::{CONTENT_LENGTH, DATE},
uri,
},
util::{
futures::Queue,
timer::{KeepAlive, Timeout},
},
};
use super::{
body::RequestBody,
proto::{
error::Error,
frame::{
PREFACE, data,
go_away::GoAway,
head,
headers::{self, ResponsePseudo},
ping::Ping,
reason::Reason,
reset::Reset,
settings::{self, Settings},
stream_id::StreamId,
window_update::WindowUpdate,
},
hpack,
size::BodySize,
stream::{RecvClose, RecvData, Remove, Stream, StreamError},
threshold::StreamRecvWindowThreshold,
},
};
/// Read-side state of one HTTP/2 connection: frame reassembly, HPACK decoding and the
/// limits applied to streams opened by the peer.
struct DecodeContext<'a, S> {
    /// Service every decoded request is dispatched to.
    service: &'a S,
    /// Connection state shared with the encoder and the response tasks.
    ctx: &'a Shared,
    /// Peer address attached to each request's extension.
    addr: SocketAddr,
    /// Date handle forwarded to response tasks.
    date: &'a DateTimeHandle,
    /// HPACK decoder for request header blocks.
    decoder: hpack::Decoder,
    /// Header-list size limit passed to the HPACK loader.
    max_header_list_size: usize,
    /// Number of live streams above which new streams are refused.
    max_concurrent_streams: usize,
    /// Receive-window threshold handed to each `RequestBody`.
    recv_threshold: StreamRecvWindowThreshold,
    /// Bytes still needed for the frame being buffered (payload plus the 6 header bytes
    /// following the length); 0 while the next length prefix has not been read.
    next_frame_len: usize,
    /// Header block waiting for further CONTINUATION frames.
    continuation: Option<(headers::Headers, BytesMut)>,
}
/// Body frame type (data or trailers) carried with `Bytes` payloads.
pub(super) type Frame = crate::body::Frame<Bytes>;
/// Buffer of received body frames shared by all streams of a connection.
pub(super) type FrameBuffer = super::util::FrameBuffer<Frame>;
/// Lifecycle of the server-initiated keep-alive PING.
#[derive(PartialEq)]
enum KeepalivePing {
    /// No keep-alive PING outstanding; set again when the peer ACKs a PING.
    Idle,
    /// The keep-alive timer asked for a PING; the encoder has not written it yet.
    Pending,
    /// The PING has been written and its ACK is awaited.
    InFlight,
}
/// Highest client stream id accepted so far.
#[derive(Clone, Copy)]
pub(super) enum LastStreamId {
    /// New streams are still accepted; the id may only increase.
    Incrementable(StreamId),
    /// Frozen after a GOAWAY was queued; streams above this id are ignored.
    Saturated(StreamId),
}
impl LastStreamId {
pub(super) fn get(self) -> StreamId {
match self {
Self::Incrementable(id) | Self::Saturated(id) => id,
}
}
pub(super) fn is_saturated(self) -> bool {
matches!(self, Self::Saturated(_))
}
pub(super) fn try_set(&mut self, id: StreamId) -> Result<Option<()>, Error> {
match self {
Self::Saturated(_) => Ok(None),
Self::Incrementable(last_id) if !id.is_client_initiated() || id <= *last_id => {
Err(Error::GoAway(Reason::PROTOCOL_ERROR))
}
Self::Incrementable(last_id) => {
*last_id = id;
Ok(Some(()))
}
}
}
pub(super) fn try_set_saturate(&mut self) -> Option<StreamId> {
match *self {
Self::Incrementable(id) => {
let _ = mem::replace(self, LastStreamId::Saturated(id));
Some(id)
}
_ => None,
}
}
}
/// Connection-wide flow-control and stream bookkeeping, shared through [`Shared`] by the
/// decoder, the encoder and every response task.
pub(super) struct FlowControl {
    /// Live streams keyed by id.
    pub(super) stream_map: HashMap<StreamId, Stream, NoHashBuilder>,
    /// Frames waiting to be encoded and written.
    pub(super) queue: WriterQueue,
    /// Received body frames awaiting their `RequestBody`.
    pub(super) frame_buf: FrameBuffer,
    /// Highest client stream id accepted so far.
    last_stream_id: LastStreamId,
    /// Connection-level bytes we may still send (grown by peer WINDOW_UPDATE on stream 0).
    send_connection_window: usize,
    /// Connection-level bytes the peer may still send before we release window.
    recv_connection_window: usize,
    /// Initial send window for new streams (peer's SETTINGS_INITIAL_WINDOW_SIZE).
    stream_window: i64,
    /// Largest frame payload we may send (peer's SETTINGS_MAX_FRAME_SIZE).
    max_frame_size: usize,
    /// Peer resets minus gracefully finished streams; too many escalates to GOAWAY.
    premature_reset_count: usize,
}
impl FlowControl {
fn insert_stream(&mut self, id: StreamId, end_stream: bool, content_length: SizeHint) {
let stream = Stream::new(self.stream_window, self.max_frame_size, content_length, end_stream);
self.stream_map.insert(id, stream);
}
fn check_not_idle(&self, id: StreamId) -> Result<(), Error> {
if id > self.last_stream_id.get() {
Err(Error::GoAway(Reason::PROTOCOL_ERROR))
} else {
Ok(())
}
}
fn recv_window_dec(&mut self, len: usize) -> Result<(), Error> {
self.recv_connection_window = self
.recv_connection_window
.checked_sub(len)
.ok_or(Error::GoAway(Reason::FLOW_CONTROL_ERROR))?;
Ok(())
}
fn update_and_wake_send_streams(&mut self, delta: i64) {
for state in self.stream_map.values_mut() {
let sf = &mut state.send;
sf.window += delta;
if sf.window > 0 {
sf.wake();
}
}
}
fn try_reset_stream(&mut self, id: StreamId) -> Result<(), Error> {
let Some(state) = self.stream_map.get_mut(&id) else {
return Ok(());
};
state.try_set_peer_reset();
self.premature_reset_count += 1;
if self.premature_reset_count > PREMATURE_RESET_LIMIT {
return Err(Error::GoAway(Reason::ENHANCE_YOUR_CALM));
}
Ok(())
}
pub(super) fn request_body_drop(&mut self, id: StreamId) {
self.try_remove_stream(id);
}
fn stream_guard_drop(&mut self, id: StreamId) {
let stream = self.stream_map.get_mut(&id).expect(STREAM_MUST_EXIST);
stream.send.set_close();
if let Some(remove) = stream.try_remove() {
self.remove_stream(id, remove);
}
}
fn try_remove_stream(&mut self, id: StreamId) {
if let Some(stream) = self.stream_map.get_mut(&id) {
match stream.maybe_close_recv(&mut self.frame_buf) {
RecvClose::Cancel(size) => {
self.queue.connection_window_update(size);
if size > 0 {
self.queue.push(Message::WindowUpdate { stream_id: id, size })
}
}
RecvClose::Close(size) => {
self.queue.connection_window_update(size);
}
}
if let Some(remove) = stream.try_remove() {
self.remove_stream(id, remove);
}
}
}
fn remove_stream(&mut self, id: StreamId, remove: Remove) {
self.stream_map.remove(&id);
match remove {
Remove::Reset(reason) => self.queue.push(Message::Reset { stream_id: id, reason }),
Remove::Graceful => self.premature_reset_count = self.premature_reset_count.saturating_sub(1),
}
}
fn handle_data(&mut self, id: StreamId, payload: Bytes, end_stream: bool) -> Result<(), Error> {
let len = payload.len();
self.recv_window_dec(len)?;
let stream = self.stream_map.get_mut(&id).ok_or_else(|| {
self.queue.connection_window_update(len);
Error::Reset(Reason::STREAM_CLOSED)
})?;
match stream.try_recv_data(&mut self.frame_buf, payload, end_stream)? {
RecvData::Queued => {}
RecvData::Discard(size) => {
self.queue.connection_window_update(size);
if end_stream {
self.try_remove_stream(id);
} else {
self.queue
.messages
.push_back(Message::WindowUpdate { stream_id: id, size });
}
}
RecvData::StreamReset(size) => {
self.queue.connection_window_update(size);
}
}
Ok(())
}
}
/// Peer SETTINGS frame that still has to be acknowledged by the encoder.
#[derive(Default)]
struct RemoteSettings {
    // Outer `Some`: an unacknowledged SETTINGS frame is pending.
    // Inner value: the header table size it carried, if any.
    header_table_size: Option<Option<u32>>,
}
impl RemoteSettings {
fn try_update(&mut self, settings: Settings) -> Result<(), Error> {
if self.header_table_size.is_some() {
return Err(Error::GoAway(Reason::ENHANCE_YOUR_CALM));
}
self.header_table_size = Some(settings.header_table_size());
Ok(())
}
fn encode(&mut self, encoder: &mut hpack::Encoder, buf: &mut BytesMut) {
if let Some(header_table_size) = self.header_table_size.take() {
if let Some(size) = header_table_size {
encoder.update_max_size(size as usize);
}
Settings::ack().encode(buf);
}
}
}
/// A frame queued for the encoder.
pub(super) enum Message {
    /// Response HEADERS.
    Head(headers::Headers<ResponsePseudo>),
    /// Response DATA.
    Data(data::Data),
    /// Trailing HEADERS.
    Trailer(headers::Headers<()>),
    /// RST_STREAM.
    Reset { stream_id: StreamId, reason: Reason },
    /// Stream-level WINDOW_UPDATE.
    WindowUpdate { stream_id: StreamId, size: usize },
    /// GOAWAY.
    GoAway { last_stream_id: StreamId, reason: Reason },
}
/// Outbound frames drained by the encoder, plus connection-level items that the encoder
/// emits once per pass instead of queuing them as messages.
pub(super) struct WriterQueue {
    /// Frames from response tasks and flow control, written in FIFO order.
    pub(super) messages: VecDeque<Message>,
    /// RST_STREAM frames for peer-caused stream errors; written before `messages`.
    resets: VecDeque<(StreamId, Reason)>,
    /// Peer SETTINGS awaiting acknowledgement.
    pending_settings: RemoteSettings,
    /// Receive window released since the last connection-level WINDOW_UPDATE.
    pending_conn_window: usize,
    /// State of the server keep-alive PING.
    keepalive_ping: KeepalivePing,
    /// Payload of the latest peer PING still to be ACKed.
    pending_client_ping: Option<[u8; 8]>,
    /// When set, an empty queue reports end-of-stream instead of `Pending`.
    closed: bool,
}
/// Single-threaded shared handle to the connection's flow-control state.
pub(super) type Shared = Rc<RefCell<FlowControl>>;
/// Panic message for the invariant that a stream outlives its body and guard.
const STREAM_MUST_EXIST: &str = "Stream MUST NOT be removed while RequestBody or StreamGuard is still alive";
impl WriterQueue {
    /// Creates an empty, open queue.
    fn new() -> Self {
        Self {
            pending_client_ping: None,
            keepalive_ping: KeepalivePing::Idle,
            pending_conn_window: 0,
            pending_settings: RemoteSettings::default(),
            closed: false,
            resets: VecDeque::new(),
            messages: VecDeque::new(),
        }
    }

    /// Appends a frame to the FIFO.
    fn push(&mut self, msg: Message) {
        self.messages.push_back(msg);
    }

    /// Adds released receive window to the next connection-level WINDOW_UPDATE.
    pub(super) fn connection_window_update(&mut self, size: usize) {
        self.pending_conn_window += size;
    }

    /// Queues a DATA frame.
    pub(super) fn push_data(&mut self, id: StreamId, payload: Bytes, end_stream: bool) {
        let mut frame = data::Data::new(id, payload);
        frame.set_end_stream(end_stream);
        self.push(Message::Data(frame));
    }

    /// Queues trailing HEADERS.
    pub(super) fn push_trailers(&mut self, id: StreamId, trailers: HeaderMap) {
        self.push(Message::Trailer(headers::Headers::trailers(id, trailers)));
    }

    /// Queues an RST_STREAM; returns `true` once the reset backlog exceeds its cap.
    fn push_client_reset(&mut self, stream_id: StreamId, reason: Reason) -> bool {
        self.resets.push_back((stream_id, reason));
        CLIENT_RESET_QUEUE_CAP < self.resets.len()
    }

    /// Flags the stream's last still-queued DATA frame as END_STREAM, or queues an empty
    /// END_STREAM DATA frame when none is queued.
    fn push_end_stream(&mut self, stream_id: StreamId) {
        let last_data = self.messages.iter_mut().rev().find_map(|msg| match msg {
            Message::Data(d) if d.stream_id() == stream_id => Some(d),
            _ => None,
        });
        match last_data {
            Some(d) => d.set_end_stream(true),
            None => self.push_data(stream_id, Bytes::new(), true),
        }
    }

    /// Marks the queue closed.
    fn close(&mut self) {
        self.closed = true;
    }

    /// Next frame to encode: resets first, then messages; `None` once closed and empty.
    fn poll_recv(&mut self) -> Poll<Option<Message>> {
        if let Some((stream_id, reason)) = self.resets.pop_front() {
            return Poll::Ready(Some(Message::Reset { stream_id, reason }));
        }
        match self.messages.pop_front() {
            Some(msg) => Poll::Ready(Some(msg)),
            None if self.closed => Poll::Ready(None),
            None => Poll::Pending,
        }
    }
}
/// A fully decoded request together with the stream it arrived on.
type Decoded = (Request<RequestExt<RequestBody>>, StreamId);
impl<'a, S> DecodeContext<'a, S> {
    /// Queues an RST_STREAM for `id` after a stream-level error.
    ///
    /// Escalates to `GoAway(ENHANCE_YOUR_CALM)` once the reset backlog exceeds
    /// `CLIENT_RESET_QUEUE_CAP`.
    fn handle_stream_reset(&self, id: StreamId, reason: Reason) -> Result<(), Error> {
        let mut inner = self.ctx.borrow_mut();
        if inner.queue.push_client_reset(id, reason) {
            return Err(Error::GoAway(Reason::ENHANCE_YOUR_CALM));
        }
        Ok(())
    }

    /// Creates a decoder using the default header-list size and HPACK table size.
    fn new(
        ctx: &'a Shared,
        service: &'a S,
        max_concurrent_streams: usize,
        recv_threshold: StreamRecvWindowThreshold,
        addr: SocketAddr,
        date: &'a DateTimeHandle,
    ) -> Self {
        Self {
            max_header_list_size: settings::DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
            max_concurrent_streams,
            recv_threshold,
            decoder: hpack::Decoder::new(settings::DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
            next_frame_len: 0,
            continuation: None,
            ctx,
            service,
            addr,
            date,
        }
    }

    /// Decodes every complete frame in `buf`, calling `on_msg` for each new request.
    ///
    /// Returns `Ok(())` once `buf` holds no complete frame. A connection error is mapped
    /// to a GOAWAY reason, queued by `go_away`, and returned as `Err(ShutDown::DrainWrite)`.
    fn decode(&mut self, buf: &mut BytesMut, mut on_msg: impl FnMut(&Self, Decoded)) -> Result<(), ShutDown> {
        let reason = loop {
            match self.try_decode(buf) {
                Ok(Some(res)) => on_msg(self, res),
                Ok(None) => return Ok(()),
                // `decode_frame` turns every `Error::Reset` into a queued RST_STREAM.
                Err(Error::Reset(_)) => unreachable!(),
                Err(Error::GoAway(reason)) => break reason,
                Err(Error::Hpack(_)) => break Reason::COMPRESSION_ERROR,
                Err(Error::MalformedMessage) => break Reason::PROTOCOL_ERROR,
            }
        };
        Err(self.go_away(reason))
    }

    /// Splits complete frames off `buf` and decodes them until one yields a request
    /// (`Ok(Some)`) or no complete frame remains (`Ok(None)`).
    fn try_decode(&mut self, buf: &mut BytesMut) -> Result<Option<Decoded>, Error> {
        loop {
            if self.next_frame_len == 0 {
                // Consume the 3-byte payload length that starts every frame header.
                if buf.len() < 3 {
                    return Ok(None);
                }
                let payload_len = buf.get_uint(3) as usize;
                // NOTE(review): this compares against the protocol default rather than the
                // SETTINGS_MAX_FRAME_SIZE advertised in `run` (`config.h2_max_frame_size`);
                // if the config is larger, legal peer frames are rejected — confirm.
                if payload_len > settings::DEFAULT_MAX_FRAME_SIZE as usize {
                    return Err(Error::GoAway(Reason::FRAME_SIZE_ERROR));
                }
                // The remaining 6 header bytes plus the payload.
                self.next_frame_len = payload_len + 6;
            }
            if buf.len() < self.next_frame_len {
                return Ok(None);
            }
            let len = mem::replace(&mut self.next_frame_len, 0);
            let mut frame = buf.split_to(len);
            let head = head::Head::parse(&frame);
            frame.advance(6);
            if let Some(decoded) = self.decode_frame(head, frame)? {
                return Ok(Some(decoded));
            }
        }
    }

    /// Queues a GOAWAY with `reason` and closes the writer queue.
    ///
    /// Does nothing if the last stream id is already saturated (a GOAWAY was queued
    /// before). Always returns `ShutDown::DrainWrite`.
    fn go_away(&self, reason: Reason) -> ShutDown {
        let mut inner = self.ctx.borrow_mut();
        if let Some(last_stream_id) = inner.last_stream_id.try_set_saturate() {
            inner.queue.push(Message::GoAway { last_stream_id, reason });
            inner.queue.close();
        }
        ShutDown::DrainWrite
    }

    /// Decodes one frame, converting a stream-level `Error::Reset` into a queued
    /// RST_STREAM on the frame's stream.
    fn decode_frame(&mut self, head: head::Head, frame: BytesMut) -> Result<Option<Decoded>, Error> {
        match self._decode_frame(head, frame) {
            Err(Error::Reset(reason)) => {
                self.handle_stream_reset(head.stream_id(), reason)?;
                Ok(None)
            }
            res => res,
        }
    }

    /// Dispatches a frame by type. Returns `Ok(Some)` only when a request is complete.
    fn _decode_frame(&mut self, head: head::Head, frame: BytesMut) -> Result<Option<Decoded>, Error> {
        match (head.kind(), &self.continuation) {
            (head::Kind::Continuation, _) => return self.handle_continuation(head, frame),
            // While a header block is open, only CONTINUATION frames are allowed.
            (_, Some(_)) => return Err(Error::GoAway(Reason::PROTOCOL_ERROR)),
            (head::Kind::Headers, _) => {
                let (headers, payload) = headers::Headers::load(head, frame)?;
                let is_end_headers = headers.is_end_headers();
                return self.handle_headers(headers, payload, is_end_headers);
            }
            (head::Kind::Data, _) => {
                let data = data::Data::load(head, frame.freeze())?;
                let is_end = data.is_end_stream();
                let id = data.stream_id();
                let payload = data.into_payload();
                let mut state = self.ctx.borrow_mut();
                state.check_not_idle(id)?;
                state.handle_data(id, payload, is_end)?;
            }
            (head::Kind::WindowUpdate, _) => {
                let window = WindowUpdate::load(head, frame.as_ref())?;
                let flow = &mut self.ctx.borrow_mut();
                let id = window.stream_id();
                flow.check_not_idle(id)?;
                match (window.size_increment(), id) {
                    // Zero increment: connection error on stream 0, stream reset otherwise.
                    (0, StreamId::ZERO) => return Err(Error::GoAway(Reason::PROTOCOL_ERROR)),
                    (0, id) => {
                        if let Some(state) = flow.stream_map.get_mut(&id) {
                            state.set_reset(StreamError::WindowUpdateZeroIncrement);
                        }
                    }
                    (incr, StreamId::ZERO) => {
                        let incr = incr as usize;
                        if flow.send_connection_window + incr > settings::MAX_INITIAL_WINDOW_SIZE {
                            return Err(Error::GoAway(Reason::FLOW_CONTROL_ERROR));
                        }
                        let was_zero = flow.send_connection_window == 0;
                        flow.send_connection_window += incr;
                        // Connection window reopened: wake streams that have stream window
                        // left (a delta of 0 leaves their windows unchanged).
                        if was_zero {
                            flow.update_and_wake_send_streams(0);
                        }
                    }
                    (incr, id) => {
                        let incr = incr as i64;
                        let window = flow.send_connection_window;
                        if let Some(state) = flow.stream_map.get_mut(&id) {
                            if state.send.window + incr > settings::MAX_INITIAL_WINDOW_SIZE as i64 {
                                state.set_reset(StreamError::WindowUpdateOverflow);
                            } else {
                                state.send.window += incr;
                                // Only wake when the connection window also allows sending.
                                if window > 0 {
                                    state.send.wake();
                                }
                            }
                        }
                    }
                }
            }
            (head::Kind::Ping, _) => {
                let ping = Ping::load(head, frame.as_ref())?;
                if ping.is_ack {
                    self.ctx.borrow_mut().queue.keepalive_ping = KeepalivePing::Idle;
                } else {
                    // Only the latest unanswered PING payload is kept for the ACK.
                    self.ctx.borrow_mut().queue.pending_client_ping = Some(ping.payload);
                }
            }
            (head::Kind::Reset, _) => {
                let reset = Reset::load(head, frame.as_ref())?;
                let id = reset.stream_id();
                if id.is_zero() {
                    return Err(Error::GoAway(Reason::PROTOCOL_ERROR));
                }
                let mut inner = self.ctx.borrow_mut();
                inner.check_not_idle(id)?;
                // May escalate to GOAWAY(ENHANCE_YOUR_CALM) after too many resets.
                inner.try_reset_stream(id)?;
            }
            (head::Kind::GoAway, _) => {
                let go_away = GoAway::load(head.stream_id(), frame.as_ref())?;
                if go_away.reason() != Reason::NO_ERROR {
                    tracing::warn!(
                        "received GOAWAY with error: {:?} last_stream={:?}",
                        go_away.reason(),
                        go_away.last_stream_id(),
                    );
                }
                // Reply with our own GOAWAY; decoding continues so existing streams can
                // finish (the returned `ShutDown` is intentionally discarded).
                self.go_away(Reason::NO_ERROR);
            }
            (head::Kind::Priority, _) => handle_priority(head.stream_id(), &frame)?,
            // PUSH_PROMISE from the peer is rejected as a connection error.
            (head::Kind::PushPromise, _) => return Err(Error::GoAway(Reason::PROTOCOL_ERROR)),
            (head::Kind::Settings, _) => self.handle_settings(head, &frame)?,
            (head::Kind::Unknown, _) => {}
        }
        Ok(None)
    }

    /// Appends a CONTINUATION payload to the pending header block of the same stream.
    #[cold]
    #[inline(never)]
    fn handle_continuation(&mut self, head: head::Head, frame: BytesMut) -> Result<Option<Decoded>, Error> {
        // 0x4 is the END_HEADERS flag.
        let is_end_headers = (head.flag() & 0x4) == 0x4;
        let (headers, mut payload) = self.continuation.take().ok_or(Error::GoAway(Reason::PROTOCOL_ERROR))?;
        if headers.stream_id() != head.stream_id() {
            return Err(Error::GoAway(Reason::PROTOCOL_ERROR));
        }
        payload.extend_from_slice(&frame);
        self.handle_headers(headers, payload, is_end_headers)
    }

    /// Applies a peer SETTINGS frame (ACKs are ignored).
    ///
    /// Shifts every stream's send window by the change in initial window size, rejecting
    /// overflow with FLOW_CONTROL_ERROR. Updates the send frame size and records the frame
    /// so the encoder acknowledges it.
    #[cold]
    #[inline(never)]
    fn handle_settings(&mut self, head: head::Head, frame: &BytesMut) -> Result<(), Error> {
        let setting = Settings::load(head, frame)?;
        if setting.is_ack() {
            return Ok(());
        }
        let mut flow = self.ctx.borrow_mut();
        if let Some(new_window) = setting.initial_window_size() {
            let new_window = new_window as i64;
            let delta = new_window - flow.stream_window;
            flow.stream_window = new_window;
            if delta > 0 {
                let overflow = flow
                    .stream_map
                    .values()
                    .any(|s| s.send.window + delta > settings::MAX_INITIAL_WINDOW_SIZE as i64);
                if overflow {
                    return Err(Error::GoAway(Reason::FLOW_CONTROL_ERROR));
                }
            }
            if delta != 0 {
                flow.update_and_wake_send_streams(delta);
            }
        }
        if let Some(frame_size) = setting.max_frame_size() {
            let frame_size = frame_size as usize;
            flow.max_frame_size = frame_size;
            for state in flow.stream_map.values_mut() {
                state.send.frame_size = frame_size;
            }
        }
        flow.queue.pending_settings.try_update(setting)
    }

    /// HPACK-decodes a header block, or stashes it when more CONTINUATION data is needed.
    fn handle_headers(
        &mut self,
        mut headers: headers::Headers,
        mut payload: BytesMut,
        is_end_headers: bool,
    ) -> Result<Option<Decoded>, Error> {
        if let Err(e) = headers.load_hpack(&mut payload, self.max_header_list_size, &mut self.decoder) {
            return match e {
                // Truncated block with CONTINUATION frames still to come: keep buffering.
                Error::Hpack(hpack::DecoderError::NeedMore(_)) if !is_end_headers => {
                    self.continuation = Some((headers, payload));
                    Ok(None)
                }
                // Malformed request: consume its stream id, then reset that stream.
                Error::MalformedMessage => {
                    let id = headers.stream_id();
                    if self.ctx.borrow_mut().last_stream_id.try_set(id)?.is_none() {
                        return Ok(None);
                    }
                    Err(Error::Reset(Reason::PROTOCOL_ERROR))
                }
                _ => Err(Error::GoAway(Reason::COMPRESSION_ERROR)),
            };
        }
        if !is_end_headers {
            self.continuation = Some((headers, payload));
            return Ok(None);
        }
        let id = headers.stream_id();
        self.handle_header_frame(id, headers)
    }

    /// Turns a complete header block into a new request, or into trailers of an
    /// existing stream.
    fn handle_header_frame(&mut self, id: StreamId, headers: headers::Headers) -> Result<Option<Decoded>, Error> {
        let end_stream = headers.is_end_stream();
        let (pseudo, headers) = headers.into_parts();
        let flow = &mut *self.ctx.borrow_mut();
        // HEADERS on an already-opened stream id is treated as trailers.
        if flow.last_stream_id.get() >= id {
            let stream = flow
                .stream_map
                .get_mut(&id)
                .ok_or(Error::GoAway(Reason::STREAM_CLOSED))?;
            match stream.try_recv_trailers(&mut flow.frame_buf, headers, end_stream)? {
                RecvData::Queued => {}
                _ => {
                    if end_stream {
                        flow.try_remove_stream(id);
                    }
                }
            }
            return Ok(None);
        }
        // After GOAWAY new streams are silently ignored.
        if flow.last_stream_id.try_set(id)?.is_none() {
            return Ok(None);
        }
        if flow.stream_map.len() >= self.max_concurrent_streams {
            return Err(Error::Reset(Reason::REFUSED_STREAM));
        }
        let content_length =
            BodySize::from_header(&headers, end_stream).map_err(|_| Error::Reset(Reason::PROTOCOL_ERROR))?;
        let method = pseudo.method.ok_or(Error::Reset(Reason::PROTOCOL_ERROR))?;
        let protocol = pseudo.protocol.map(|proto| Protocol::from_str(&proto));
        // Plain CONNECT (no :protocol) must omit :scheme and :path; all other requests
        // must carry both.
        let is_strict_connect = method == Method::CONNECT && protocol.is_none();
        let mut uri_parts = uri::Parts::default();
        // Unparsable :authority / :scheme / :path values are dropped, not rejected.
        if let Some(authority) = pseudo.authority {
            if let Ok(a) = uri::Authority::from_maybe_shared(authority.into_inner()) {
                uri_parts.authority = Some(a);
            }
        }
        match (is_strict_connect, pseudo.scheme) {
            (true, Some(_)) | (false, None) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)),
            // The scheme is only recorded alongside an authority.
            (false, Some(scheme)) if uri_parts.authority.is_some() => {
                if let Ok(s) = uri::Scheme::try_from(scheme.as_str()) {
                    uri_parts.scheme = Some(s);
                }
            }
            _ => {}
        }
        match (is_strict_connect, pseudo.path) {
            (true, Some(_)) | (false, None) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)),
            (_, Some(path)) if !path.is_empty() => {
                if let Ok(pq) = uri::PathAndQuery::from_maybe_shared(path.into_inner()) {
                    uri_parts.path_and_query = Some(pq);
                }
            }
            // Non-CONNECT request with an empty :path.
            (false, _) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)),
            _ => {}
        }
        let ext = Extension::with_protocol(self.addr, protocol);
        let mut req = Request::new(RequestExt::from_parts((), ext));
        *req.version_mut() = Version::HTTP_2;
        *req.headers_mut() = headers;
        *req.method_mut() = method;
        // If the parts do not form a valid URI the request keeps its default URI.
        if let Ok(uri) = Uri::from_parts(uri_parts) {
            *req.uri_mut() = uri;
        }
        let body = RequestBody::new(id, content_length, Rc::clone(self.ctx), self.recv_threshold);
        flow.insert_stream(id, end_stream, content_length);
        let req = req.map(|ext| ext.map_body(|_| body));
        Ok(Some((req, id)))
    }
}
/// Reads more bytes from `io` into the spare capacity after `buf`'s current contents.
///
/// When `buf` already holds `LIMIT` bytes or more, the future never completes. This
/// applies back-pressure until the caller drains the buffer and replaces the future.
async fn read_io<const LIMIT: usize>(mut buf: BytesMut, io: &impl AsyncBufRead) -> (io::Result<usize>, BytesMut) {
    let filled = buf.len();
    if filled >= LIMIT {
        return core::future::pending().await;
    }
    buf.reserve(4096);
    let (res, spare) = io.read(buf.slice(filled..)).await;
    (res, spare.into_inner())
}
/// Writes all of `buf` to `io` and hands the buffer back emptied for reuse.
async fn write_io(buf: BytesMut, io: &impl AsyncBufWrite) -> (io::Result<()>, BytesMut) {
    let (res, mut written) = write_all(io, buf).await;
    written.clear();
    (res, written)
}
/// Write-side state of a connection: turns queued messages into wire bytes.
struct EncodeContext<'a> {
    /// Connection state shared with the decoder and the response tasks.
    ctx: &'a Shared,
    /// HPACK encoder for response headers and trailers.
    encoder: hpack::Encoder,
}
impl<'a> EncodeContext<'a> {
    /// Creates an encoder over the shared connection state.
    fn new(ctx: &'a Shared) -> Self {
        Self {
            encoder: hpack::Encoder::new(65535, 4096),
            ctx,
        }
    }

    /// Serializes everything currently queued for the peer into `write_buf`.
    ///
    /// Output order:
    /// 1. A pending SETTINGS ACK (with any HPACK table-size change).
    /// 2. Queued messages, resets first.
    /// 3. A connection-level WINDOW_UPDATE for released receive window.
    /// 4. A PING ACK for the latest peer PING.
    /// 5. A keep-alive PING if one was requested.
    ///
    /// Returns `Ready(true)` when `write_buf` has bytes to write, `Ready(false)` when the
    /// queue is closed and nothing is left, and `Pending` otherwise.
    /// NOTE(review): no waker is registered on `Pending`; this appears to rely on the
    /// caller being re-polled when sibling futures on the same task make progress.
    fn poll_encode(&mut self, write_buf: &mut BytesMut) -> Poll<bool> {
        let mut flow = self.ctx.borrow_mut();
        flow.queue.pending_settings.encode(&mut self.encoder, write_buf);
        let writable = loop {
            match flow.queue.poll_recv() {
                Poll::Ready(Some(msg)) => match msg {
                    Message::Head(headers) => {
                        // Each encode step is bounded by the peer's max frame size; any
                        // remainder is written by further `encode` calls (presumably
                        // CONTINUATION frames).
                        let frame_size = flow.max_frame_size;
                        let mut cont = headers.encode(&mut self.encoder, &mut write_buf.limit(frame_size));
                        while let Some(c) = cont {
                            cont = c.encode(&mut write_buf.limit(frame_size));
                        }
                    }
                    Message::Trailer(headers) => {
                        let frame_size = flow.max_frame_size;
                        let mut cont = headers.encode(&mut self.encoder, &mut write_buf.limit(frame_size));
                        while let Some(c) = cont {
                            cont = c.encode(&mut write_buf.limit(frame_size));
                        }
                    }
                    Message::Data(mut data) => data.encode_chunk(write_buf),
                    Message::Reset { stream_id, reason } => Reset::new(stream_id, reason).encode(write_buf),
                    Message::WindowUpdate { stream_id, size } => {
                        WindowUpdate::new(stream_id, size as _).encode(write_buf)
                    }
                    Message::GoAway { last_stream_id, reason } => {
                        // No connection window is released once GOAWAY goes out.
                        flow.queue.pending_conn_window = 0;
                        GoAway::new(last_stream_id, reason).encode(write_buf);
                        break true;
                    }
                },
                Poll::Pending => break true,
                Poll::Ready(None) => break false,
            }
        };
        // Credit released window back to our receive window and tell the peer.
        let pending = mem::replace(&mut flow.queue.pending_conn_window, 0);
        if pending > 0 {
            flow.recv_connection_window += pending;
            WindowUpdate::new(StreamId::zero(), pending as _).encode(write_buf);
        }
        // PING with flag 0x1 (ACK) echoing the peer's payload.
        if let Some(payload) = flow.queue.pending_client_ping.take() {
            head::Head::new(head::Kind::Ping, 0x1, StreamId::zero()).encode(8, write_buf);
            write_buf.put_slice(&payload);
        }
        // Keep-alive PING with an all-zero payload.
        if flow.queue.keepalive_ping == KeepalivePing::Pending {
            head::Head::new(head::Kind::Ping, 0x0, StreamId::zero()).encode(8, write_buf);
            write_buf.put_slice(&[0u8; 8]);
            flow.queue.keepalive_ping = KeepalivePing::InFlight;
        }
        if !write_buf.is_empty() {
            Poll::Ready(true)
        } else if !writable {
            Poll::Ready(false)
        } else {
            Poll::Pending
        }
    }
}
async fn response_task<S, ReqB, ResB, ResBE>(
req: Request<RequestExt<RequestBody>>,
stream_id: StreamId,
service: &S,
ctx: &Shared,
date: &DateTimeHandle,
) where
S: Service<Request<RequestExt<ReqB>>, Response = Response<ResB>>,
S::Error: fmt::Debug,
ReqB: From<RequestBody>,
ResB: Body<Data = Bytes, Error = ResBE>,
ResBE: fmt::Debug,
{
let guard = StreamGuard { stream_id, ctx };
let req = req.map(|ext| ext.map_body(From::from));
let res = match service.call(req).await {
Ok(res) => res,
Err(_) => {
let mut flow = ctx.borrow_mut();
let stream = flow.stream_map.get_mut(&stream_id).unwrap();
stream.set_reset(StreamError::InternalError);
return;
}
};
let (mut parts, body) = res.into_parts();
let size = body.size_hint();
if let SizeHint::Exact(size) = size {
parts.headers.insert(CONTENT_LENGTH, size.into());
}
if !parts.headers.contains_key(DATE) {
let date = date.with_date_header(Clone::clone);
parts.headers.insert(DATE, date);
}
let pseudo = headers::Pseudo::response(parts.status);
let mut headers = headers::Headers::new(stream_id, pseudo, parts.headers);
let has_body = !matches!(size, SizeHint::None);
if !has_body {
headers.set_end_stream();
}
ctx.borrow_mut().queue.push(Message::Head(headers));
if has_body {
let mut body = pin!(body);
'body: loop {
match poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
None => break ctx.borrow_mut().queue.push_end_stream(guard.stream_id),
Some(Err(e)) => {
error!("body error: {:?}", e);
let mut flow = guard.ctx.borrow_mut();
if let Some(state) = flow.stream_map.get_mut(&guard.stream_id) {
state.set_reset(StreamError::InternalError);
}
break;
}
Some(Ok(Frame::Data(bytes))) => {
if guard.send_data(bytes, body.is_end_stream()).await {
break 'body;
}
}
Some(Ok(Frame::Trailers(trailers))) => break guard.send_trailers(trailers),
}
}
}
}
/// Ties a response task to its stream; dropping it closes the stream's send side.
struct StreamGuard<'a> {
    /// Shared connection state the stream lives in.
    ctx: &'a Shared,
    /// Stream this guard belongs to.
    stream_id: StreamId,
}
impl Drop for StreamGuard<'_> {
    /// Closes the send side and removes the stream if its receive side is done too.
    fn drop(&mut self) {
        let id = self.stream_id;
        let mut flow = self.ctx.borrow_mut();
        flow.stream_guard_drop(id);
    }
}
impl StreamGuard<'_> {
    /// Queues `data` as one or more DATA frames, waiting for send window as needed.
    ///
    /// Returns `true` when the caller must stop producing body frames. That happens when
    /// the END_STREAM frame was queued, or when `poll_send_window` yields `None` or an
    /// error (the stream can no longer send). Returns `false` when all of `data` was
    /// queued and more frames may follow, including for an empty non-final chunk, which
    /// is skipped with a warning.
    async fn send_data(&self, mut data: Bytes, end_stream: bool) -> bool {
        if data.is_empty() && !end_stream {
            tracing::warn!("response body should not yield empty Frame::Data unless it's the last chunk of Body");
            return false;
        }
        loop {
            let len = data.len();
            let opt = poll_fn(|cx| {
                let mut flow = self.ctx.borrow_mut();
                let window = flow.send_connection_window;
                flow.stream_map
                    .get_mut(&self.stream_id)
                    .expect(STREAM_MUST_EXIST)
                    .poll_send_window(len, window, cx)
                    .map_ok(|aval| {
                        // Charge the connection window and hand the borrow out, so the
                        // frame is queued under the same borrow.
                        flow.send_connection_window -= aval;
                        (aval, flow)
                    })
            })
            .await;
            let Some(Ok((aval, mut flow))) = opt else {
                return true;
            };
            // Send as much as was granted; keep the rest for the next round.
            let all_consumed = aval == len;
            let payload = if all_consumed {
                mem::take(&mut data)
            } else {
                data.split_to(aval)
            };
            let end_stream = all_consumed && end_stream;
            flow.queue.push_data(self.stream_id, payload, end_stream);
            if end_stream {
                return true;
            } else if all_consumed {
                return false;
            }
        }
    }

    /// Queues trailing HEADERS for this stream.
    fn send_trailers(&self, trailers: HeaderMap) {
        self.ctx.borrow_mut().queue.push_trailers(self.stream_id, trailers);
    }
}
/// Keep-alive driver: periodically requests a PING and fails if the last one is unanswered.
struct PingPong<'a> {
    /// Shared state whose `keepalive_ping` this driver updates.
    ctx: &'a Shared,
    /// Clock used to compute the next deadline.
    date: &'a DateTimeHandle,
    /// Interval between keep-alive PINGs.
    ka_dur: Duration,
    /// Deadline timer.
    timer: Pin<&'a mut KeepAlive>,
}
impl<'a> PingPong<'a> {
fn new(timer: Pin<&'a mut KeepAlive>, ctx: &'a Shared, date: &'a DateTimeHandle, ka_dur: Duration) -> Self {
Self {
timer,
ctx,
date,
ka_dur,
}
}
async fn tick(&mut self) -> io::Result<()> {
self.timer.as_mut().await;
{
let mut inner = self.ctx.borrow_mut();
if inner.queue.keepalive_ping != KeepalivePing::Idle {
return Err(io::Error::new(io::ErrorKind::TimedOut, "h2 ping timeout"));
}
inner.queue.keepalive_ping = KeepalivePing::Pending;
}
self.timer.as_mut().update(self.date.now() + self.ka_dur);
Ok(())
}
}
/// Net peer resets (resets minus graceful completions) beyond which the connection is
/// closed with ENHANCE_YOUR_CALM.
const PREMATURE_RESET_LIMIT: usize = 100;
/// Length of the pending RST_STREAM backlog beyond which the connection is closed with
/// ENHANCE_YOUR_CALM.
const CLIENT_RESET_QUEUE_CAP: usize = 64;
/// Sniffs whether the connection opens with the HTTP/2 client preface.
///
/// Returns `HTTP_2` when the preface is present and `HTTP_11` otherwise, including when
/// reading failed. The bytes read so far are returned with it, so the chosen protocol can
/// continue from them. This function never returns `Err`.
pub(crate) async fn peek_version(
    io: &(impl AsyncBufRead + AsyncBufWrite),
    buf: BytesMut,
) -> io::Result<(Version, BytesMut)> {
    let (read_buf, preface) = prefix_check::<4096>(buf, io).await;
    let version = match preface {
        Ok(()) => Version::HTTP_2,
        Err(_) => Version::HTTP_11,
    };
    Ok((version, read_buf))
}
/// Serves one HTTP/2 connection until it is closed, times out or fails.
///
/// After the handshake (client preface check plus our SETTINGS), several activities are
/// multiplexed on the current task:
/// - reading and decoding frames;
/// - one `response_task` per request;
/// - encoding and writing queued frames;
/// - a keep-alive timer that sends PINGs and fails the connection when one goes
///   unanswered.
///
/// On shutdown the write side is drained before the socket's write half is shut down.
///
/// # Errors
/// A handshake error, a write error, a keep-alive timeout, or the read error that ended
/// the connection.
pub(crate) async fn run<
    Io,
    S,
    ReqB,
    ResB,
    ResBE,
    const HEADER_LIMIT: usize,
    const READ_BUF_LIMIT: usize,
    const WRITE_BUF_LIMIT: usize,
>(
    io: Io,
    addr: SocketAddr,
    read_buf: BytesMut,
    mut ka: Pin<&mut KeepAlive>,
    service: &S,
    date: &DateTimeHandle,
    config: &HttpServiceConfig<HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>,
) -> io::Result<()>
where
    Io: AsyncBufRead + AsyncBufWrite,
    S: Service<Request<RequestExt<ReqB>>, Response = Response<ResB>>,
    ReqB: From<RequestBody>,
    S::Error: fmt::Debug,
    ResB: Body<Data = Bytes, Error = ResBE>,
    ResBE: fmt::Debug,
{
    let mut settings = settings::Settings::default();
    settings.set_max_concurrent_streams(Some(config.h2_max_concurrent_streams));
    settings.set_initial_window_size(Some(config.h2_initial_window_size));
    settings.set_max_frame_size(Some(config.h2_max_frame_size));
    settings.set_max_header_list_size(Some(config.h2_max_header_list_size));
    settings.set_enable_connect_protocol(Some(1));
    let (mut read_buf, mut write_buf) = handshake::<READ_BUF_LIMIT>(&io, read_buf, &settings, ka.as_mut()).await?;
    let max_concurrent_streams = config.h2_max_concurrent_streams as usize;

    // RFC 9113 §6.9.2: SETTINGS_INITIAL_WINDOW_SIZE only affects stream windows. The
    // connection window starts at the protocol default on both sides and can only change
    // through WINDOW_UPDATE frames on stream 0.
    let default_window = settings::DEFAULT_INITIAL_WINDOW_SIZE as usize;
    let configured_window = config.h2_initial_window_size as usize;
    let mut writer_queue = WriterQueue::new();
    if configured_window > default_window {
        // Grow our receive connection window to the configured size. The encoder adds this
        // to `recv_connection_window` when it writes the stream-0 WINDOW_UPDATE.
        writer_queue.connection_window_update(configured_window - default_window);
    }

    let shared = Rc::new(RefCell::new(FlowControl {
        send_connection_window: default_window,
        stream_window: settings::DEFAULT_INITIAL_WINDOW_SIZE as i64,
        // `max_frame_size` bounds the frames WE send. It must start at the protocol
        // default and then follow the peer's SETTINGS; `config.h2_max_frame_size` only
        // describes what we accept.
        max_frame_size: settings::DEFAULT_MAX_FRAME_SIZE as usize,
        recv_connection_window: default_window,
        stream_map: HashMap::with_capacity_and_hasher(max_concurrent_streams, NoHashBuilder::default()),
        premature_reset_count: 0,
        last_stream_id: LastStreamId::Incrementable(StreamId::ZERO),
        queue: writer_queue,
        frame_buf: FrameBuffer::new(),
    }));
    let recv_threshold = StreamRecvWindowThreshold::from(&settings);
    let mut ctx = DecodeContext::new(&shared, service, max_concurrent_streams, recv_threshold, addr, date);
    let mut enc = EncodeContext::new(&shared);
    let mut queue = Queue::new();
    let mut ping_pong = PingPong::new(ka, &shared, date, config.keep_alive_timeout);
    let res = {
        let mut read_task = pin!(read_io::<READ_BUF_LIMIT>(read_buf, &io));
        // Writes whatever the encoder produces until the writer queue is closed and empty.
        let mut write_task = pin!(async {
            while poll_fn(|_| enc.poll_encode(&mut write_buf)).await {
                let (res, buf) = io.write(write_buf).await;
                write_buf = buf;
                match res {
                    Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
                    Ok(n) => write_buf.advance(n),
                    Err(e) => return Err(e),
                }
            }
            Ok(())
        });
        let shutdown = loop {
            match read_task
                .as_mut()
                .select(async {
                    // Drive response tasks; after GOAWAY, finish once they have all completed.
                    loop {
                        if queue.is_empty() && shared.borrow().last_stream_id.is_saturated() {
                            return;
                        }
                        let _ = queue.next().await;
                    }
                })
                .select(write_task.as_mut())
                .select(ping_pong.tick())
                .await
            {
                SelectOutput::A(SelectOutput::A(SelectOutput::A((res, buf)))) => {
                    read_buf = buf;
                    match res {
                        Ok(n) if n > 0 => {
                            if let Err(shutdown) = ctx.decode(&mut read_buf, |decoder, (req, id)| {
                                queue.push(response_task(req, id, decoder.service, decoder.ctx, decoder.date));
                            }) {
                                break shutdown;
                            }
                        }
                        res => break ShutDown::ReadClosed(res.map(|_| ())),
                    };
                    read_task.set(read_io(read_buf, &io));
                }
                SelectOutput::A(SelectOutput::A(SelectOutput::B(_))) => break ShutDown::DrainWrite,
                SelectOutput::A(SelectOutput::B(res)) => break ShutDown::WriteClosed(res),
                SelectOutput::B(Ok(_)) => {}
                SelectOutput::B(Err(e)) => break ShutDown::Timeout(e),
            }
        };
        Box::pin(async {
            let mut read_res = Ok(());
            match shutdown {
                ShutDown::WriteClosed(res) => return res,
                ShutDown::Timeout(err) => return Err(err),
                ShutDown::ReadClosed(res) => {
                    {
                        let mut flow = shared.borrow_mut();
                        for state in flow.stream_map.values_mut() {
                            state.try_set_peer_reset();
                            state.recv.set_close_2();
                        }
                        flow.queue.close();
                    }
                    read_res = res;
                }
                ShutDown::DrainWrite => queue.clear(),
            }
            // Keep driving tasks and writes until the write side finishes.
            loop {
                match queue.next().select(write_task.as_mut()).select(ping_pong.tick()).await {
                    SelectOutput::A(SelectOutput::A(_)) => {}
                    SelectOutput::A(SelectOutput::B(res)) => {
                        res?;
                        break read_res;
                    }
                    SelectOutput::B(res) => res?,
                }
            }
        })
        .await
    };
    let _ = io.shutdown(Shutdown::Write).await;
    res
}
/// Why the main connection loop in `run` stopped.
enum ShutDown {
    /// Reading ended (EOF or error); carries the read result to return.
    ReadClosed(io::Result<()>),
    /// The write task finished; its result is returned directly.
    WriteClosed(io::Result<()>),
    /// The keep-alive PING went unanswered.
    Timeout(io::Error),
    /// A GOAWAY was queued or all tasks finished after GOAWAY; flush pending writes.
    DrainWrite,
}
#[cold]
#[inline(never)]
fn handle_priority(id: StreamId, payload: &[u8]) -> Result<(), Error> {
if id.is_zero() {
Err(Error::GoAway(Reason::PROTOCOL_ERROR))
} else if payload.len() != 5 {
Err(Error::Reset(Reason::FRAME_SIZE_ERROR))
} else if id == StreamId::parse(&payload[..4]).0 {
Err(Error::Reset(Reason::PROTOCOL_ERROR))
} else {
Ok(())
}
}
/// Heap-allocated, type-erased future; used to keep cold paths out of the caller's state machine.
type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
/// Performs the server side of the HTTP/2 connection preface under a timeout.
///
/// The steps are:
/// 1. Verify the client preface.
/// 2. Strip it from the read buffer.
/// 3. Write our SETTINGS frame.
///
/// Returns the read buffer (positioned after the preface) and the emptied write buffer.
///
/// # Errors
/// `TimedOut` when `timer` fires first, otherwise any preface or I/O error.
#[cold]
#[inline(never)]
fn handshake<'a, const LIMIT: usize>(
    io: &'a (impl AsyncBufRead + AsyncBufWrite),
    buf: BytesMut,
    settings: &'a settings::Settings,
    timer: Pin<&'a mut KeepAlive>,
) -> BoxedFuture<'a, io::Result<(BytesMut, BytesMut)>> {
    // `move` is required: the boxed future outlives this function, so it must own copies
    // of the `io` / `settings` references (and `buf` / `timer`) instead of borrowing the
    // function's parameters.
    Box::pin(async move {
        async {
            let (mut read_buf, res) = prefix_check::<LIMIT>(buf, io).await;
            res?;
            read_buf.advance(PREFACE.len());
            let mut write_buf = BytesMut::new();
            settings.encode(&mut write_buf);
            let (res, write_buf) = write_io(write_buf, io).await;
            res?;
            Ok((read_buf, write_buf))
        }
        .timeout(timer)
        .await
        .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "h2 handshake timeout"))?
    })
}
/// Reads from `io` until `read_buf` can be matched against the HTTP/2 client preface.
///
/// Returns the buffer (with everything read so far) together with:
/// - `Ok(())` when it starts with [`PREFACE`];
/// - `Err(InvalidData)` as soon as the buffered bytes diverge from the preface;
/// - `Err(UnexpectedEof)` when the peer closes before a full preface arrived;
/// - any I/O error returned by the read.
async fn prefix_check<const LIMIT: usize>(
    mut read_buf: BytesMut,
    io: &(impl AsyncBufRead + AsyncBufWrite),
) -> (BytesMut, io::Result<()>) {
    // Keep reading only while the bytes so far could still be the start of the preface.
    // Stopping early on a mismatch lets short non-HTTP/2 input (such as a small HTTP/1
    // request) be classified without waiting for data that may never arrive.
    while read_buf.len() < PREFACE.len() && PREFACE.starts_with(&read_buf[..]) {
        let (res, b) = read_io::<LIMIT>(read_buf, io).await;
        read_buf = b;
        match res {
            // A zero-length read is EOF; without this check the loop would spin forever.
            Ok(0) => {
                let err = io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed before HTTP/2 client preface",
                );
                return (read_buf, Err(err));
            }
            Ok(_) => {}
            Err(e) => return (read_buf, Err(e)),
        }
    }
    let res = if !read_buf.starts_with(PREFACE) {
        Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "invalid HTTP/2 client preface",
        ))
    } else {
        Ok(())
    };
    (read_buf, res)
}