use core::{
cell::RefCell,
mem,
task::{Context, Poll, Waker, ready},
time::Duration,
};
use std::{
collections::{HashMap, VecDeque},
io,
rc::Rc,
};
use xitca_unsafe_collection::no_hash::NoHashBuilder;
use crate::{
body::SizeHint,
bytes::{BufMut, Bytes, BytesMut},
http::{Method, Protocol, Request, Version, header::HeaderMap, uri},
};
use super::{
error::Error as ProtoError,
frame::{
data::Data,
go_away::GoAway,
head,
headers::{Headers, ResponsePseudo},
ping::Ping,
priority::Priority,
reason::Reason,
reset::Reset,
settings::{self, Settings},
stream_id::StreamId,
window_update::WindowUpdate,
},
hpack,
last_stream_id::LastStreamId,
reset_counter::ResetCounter,
size::BodySize,
stream::{RecvData, Stream, StreamError, TryRemove},
threshold::RecvWindowThreshold,
window::{RecvWindow, SendWindow},
};
/// Panic message for lookups that require a stream entry to still be tracked in the stream map.
const STREAM_MUST_EXIST: &str = "Stream MUST NOT be removed while RequestBody or response_task is still alive";

/// Interior-mutable cell around the connection state (`RefCell`, so single-threaded only).
pub(crate) type FlowControlLock = RefCell<FlowControl>;
/// Reference-counted handle to [`FlowControlLock`], cloned for each party that needs access.
pub(crate) type FlowControlClone = Rc<FlowControlLock>;

/// Body frame type exchanged with request bodies.
pub(crate) type Frame = crate::body::Frame<Bytes>;
/// Buffer of [`Frame`]s handed to streams when receiving or polling body data.
pub(crate) type FrameBuffer = crate::h2::util::FrameBuffer<Frame>;

/// A newly decoded request together with the id of the stream it arrived on.
///
/// The request's body slot carries the body `SizeHint` and the optional `:protocol`
/// pseudo header (extended CONNECT).
pub(crate) type DecodedRequest = (Request<(SizeHint, Option<Protocol>)>, StreamId);
/// Connection-wide HTTP/2 state: send/receive flow-control windows, the table of open
/// streams, the reset rate limiter and the queue of outgoing frames.
///
/// Shared through [`FlowControlClone`] (`Rc<RefCell<FlowControl>>`).
pub(crate) struct FlowControl {
// Once `stream_map` holds this many entries, new streams are refused with REFUSED_STREAM.
max_concurrent_streams: usize,
// Peer's maximum frame size (starts at `settings::DEFAULT_MAX_FRAME_SIZE`, updated from the
// peer's SETTINGS); bounds outgoing DATA chunks and header frames.
max_frame_size: SendWindow,
// Connection-level send window, grown by WINDOW_UPDATE on stream 0 and drawn from when
// sending DATA.
send_connection_window: SendWindow,
// Initial send window for new streams, re-based by the peer's SETTINGS_INITIAL_WINDOW_SIZE.
send_stream_initial_window: SendWindow,
// Initial receive window advertised for new streams (from our local settings).
recv_stream_initial_window: RecvWindow,
// Remaining connection-level receive window; charged on incoming DATA and replenished when
// pending connection window updates are encoded.
recv_connection_window: RecvWindow,
// Amount of consumed body bytes after which receive window is credited back to the peer.
recv_threshold: RecvWindowThreshold,
// Streams currently tracked, keyed by stream id.
stream_map: HashMap<StreamId, Stream, NoHashBuilder>,
// Rate limiter for stream resets; tripping it leads to GOAWAY(ENHANCE_YOUR_CALM).
reset_counter: ResetCounter,
// Highest peer stream id seen; used for idle-stream checks and the GOAWAY last stream id.
last_stream_id: LastStreamId,
// Outgoing frames and control state consumed by `poll_encode`.
queue: WriterQueue,
// Frame buffer passed to streams when they receive, drop or yield body frames.
frame_buf: FrameBuffer,
}
impl FlowControl {
/// Number of rate-limited stream resets tolerated by the reset counter within [`Self::RESET_WINDOW`].
const RESET_MAX: usize = 20;
/// Time span used by the reset counter together with [`Self::RESET_MAX`].
const RESET_WINDOW: Duration = Duration::from_secs(30);

/// Build the connection state from the local `settings` we advertise to the peer.
///
/// Receive-side windows and the concurrency limit come from `settings`. Send-side windows
/// start at `SendWindow::default()` and the max frame size at
/// `settings::DEFAULT_MAX_FRAME_SIZE` until the peer's SETTINGS frame arrives.
///
/// # Panics
///
/// If `settings` has no initial window size or no max concurrent streams value.
pub(crate) fn new(settings: &Settings) -> Self {
    let recv_stream_initial_window = RecvWindow::new(settings.initial_window_size().unwrap());
    let max_concurrent_streams: usize = settings.max_concurrent_streams().unwrap() as _;
    let recv_threshold = RecvWindowThreshold::from(settings);

    // Pre-size the stream table for the maximum number of concurrently open streams.
    let stream_map = HashMap::with_capacity_and_hasher(max_concurrent_streams, NoHashBuilder::default());
    let reset_counter = ResetCounter::new(Self::RESET_MAX, Self::RESET_WINDOW);

    Self {
        max_concurrent_streams,
        max_frame_size: SendWindow::from_u32(settings::DEFAULT_MAX_FRAME_SIZE),
        send_connection_window: SendWindow::default(),
        send_stream_initial_window: SendWindow::default(),
        recv_stream_initial_window,
        recv_connection_window: RecvWindow::default(),
        recv_threshold,
        stream_map,
        reset_counter,
        last_stream_id: LastStreamId::new(),
        queue: WriterQueue::new(),
        frame_buf: FrameBuffer::new(),
    }
}
fn check_not_idle(&self, id: StreamId) -> Result<(), Error> {
if self.last_stream_id.check_idle(id) {
Err(Error::GoAway(Reason::PROTOCOL_ERROR))
} else {
Ok(())
}
}
fn recv_window_dec(&mut self, len: RecvWindow) -> Result<(), Error> {
self.recv_connection_window
.checked_sub(len)
.map_err(|_| Error::GoAway(Reason::FLOW_CONTROL_ERROR))
}
fn update_and_wake_send_streams(&mut self, delta: SendWindow) {
let conn_window_positive = self.send_connection_window.is_positive();
for stream in self.stream_map.values_mut() {
stream.send_window_update(delta, conn_window_positive);
}
}
/// Release the receive side of stream `id` once its request body is no longer in use.
///
/// The window returned by the stream's `maybe_close_recv`, plus `pending_window` (consumed
/// bytes not yet credited back), is queued as a connection-level window update. The stream is
/// then removed if `try_remove` allows it. If removal yields a stream error, a RST_STREAM is
/// queued. If queuing that reset trips the reset rate limit, the connection is handed to
/// [`Self::go_away`] instead. The writer is woken when a non-zero window was released or a
/// reset path was taken.
///
/// # Panics
///
/// If the stream for `id` is no longer tracked.
pub(crate) fn request_body_drop(&mut self, id: StreamId, pending_window: RecvWindow) {
    let stream = self.stream_map.get_mut(&id).expect(STREAM_MUST_EXIST);
    let released = stream.maybe_close_recv(&mut self.frame_buf) + pending_window;
    self.queue.connection_window_update(released);
    let removal = stream.try_remove();

    let needs_wake = match self.remove_stream(id, removal) {
        Ok(()) => released != RecvWindow::ZERO,
        Err(stream_err) => {
            if let Err(conn_err) = self.try_push_reset(id, stream_err.reason()) {
                self.go_away(conn_err);
            }
            true
        }
    };

    if needs_wake {
        self.queue.wake();
    }
}
/// Close the send side of stream `id` after its response task finished.
///
/// The stream is removed if `try_remove` allows it. If removal yields a stream error, a
/// RST_STREAM is queued for `id`.
///
/// # Errors
///
/// `Err(())` when queuing that reset tripped the reset rate limit. The connection error has
/// then already been passed to [`Self::go_away`].
///
/// # Panics
///
/// If the stream for `id` is no longer tracked.
pub(crate) fn response_task_done(&mut self, id: StreamId) -> Result<(), ()> {
    let stream = self.stream_map.get_mut(&id).expect(STREAM_MUST_EXIST);
    stream.close_send();
    let removal = stream.try_remove();

    let Err(stream_err) = self.remove_stream(id, removal) else {
        return Ok(());
    };

    match self.try_push_reset(id, stream_err.reason()) {
        Ok(()) => Ok(()),
        Err(conn_err) => {
            self.go_away(conn_err);
            Err(())
        }
    }
}
fn inline_remove_stream(&mut self, id: StreamId) -> Result<(), Error> {
if let Some(stream) = self.stream_map.get_mut(&id) {
stream.promote_cancel_to_close_recv();
let remove = stream.try_remove();
self.remove_stream(id, remove)?;
};
Ok(())
}
/// Apply a [`TryRemove`] decision for stream `id`.
///
/// `Remove` and `ResetRemove` drop the entry from `stream_map`. `Keep` and `ResetKeep` leave it
/// in place. Both `Reset*` variants return their stream error so the caller can queue a reset.
fn remove_stream(&mut self, id: StreamId, remove: TryRemove) -> Result<(), StreamError> {
    match remove {
        TryRemove::Keep => Ok(()),
        TryRemove::ResetKeep(err) => Err(err),
        TryRemove::Remove => {
            self.stream_map.remove(&id);
            Ok(())
        }
        TryRemove::ResetRemove(err) => {
            self.stream_map.remove(&id);
            Err(err)
        }
    }
}
/// Queue a RST_STREAM frame for stream `id` with `reason`.
///
/// Resets with a reason other than `NO_ERROR` or `INTERNAL_ERROR` count against the reset rate
/// limit.
///
/// # Errors
///
/// `Error::GoAway(ENHANCE_YOUR_CALM)` when the rate limit trips. Nothing is queued in that case.
#[cold]
#[inline(never)]
pub(crate) fn try_push_reset(&mut self, id: StreamId, reason: Reason) -> Result<(), Error> {
    let rate_limited = !matches!(reason, Reason::NO_ERROR | Reason::INTERNAL_ERROR);
    if rate_limited {
        self.try_tick_reset()?;
    }
    self.queue.push(Message::Reset { stream_id: id, reason });
    Ok(())
}
pub(crate) fn recv_header(&mut self, id: StreamId, headers: Headers) -> Result<Option<DecodedRequest>, Error> {
let end_stream = headers.is_end_stream();
let (pseudo, headers) = headers.into_parts();
if !self.last_stream_id.check_idle(id) {
let stream = self
.stream_map
.get_mut(&id)
.ok_or(Error::GoAway(Reason::STREAM_CLOSED))?;
match stream.try_recv_trailers(&mut self.frame_buf, headers, end_stream)? {
RecvData::Queued(_) => {}
_ => self.inline_remove_stream(id)?,
}
return Ok(None);
}
if self.try_set_last_stream_id(id)?.is_none() {
return Ok(None);
}
if self.stream_map.len() >= self.max_concurrent_streams {
return Err(Error::Reset(Reason::REFUSED_STREAM));
}
let content_length =
BodySize::from_header(&headers, end_stream).map_err(|_| Error::Reset(Reason::PROTOCOL_ERROR))?;
let method = pseudo.method.ok_or(Error::Reset(Reason::PROTOCOL_ERROR))?;
let protocol = pseudo.protocol.map(|proto| Protocol::from_str(&proto));
let is_strict_connect = method == Method::CONNECT && protocol.is_none();
let mut uri_parts = uri::Parts::default();
if let Some(authority) = pseudo.authority {
if let Ok(a) = uri::Authority::from_maybe_shared(authority.into_inner()) {
uri_parts.authority = Some(a);
}
}
match (is_strict_connect, pseudo.scheme) {
(true, Some(_)) | (false, None) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)),
(false, Some(scheme)) if uri_parts.authority.is_some() => {
if let Ok(s) = uri::Scheme::try_from(scheme.as_str()) {
uri_parts.scheme = Some(s);
}
}
_ => {}
}
match (is_strict_connect, pseudo.path) {
(true, Some(_)) | (false, None) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)),
(_, Some(path)) if !path.is_empty() => {
if let Ok(pq) = uri::PathAndQuery::from_maybe_shared(path.into_inner()) {
uri_parts.path_and_query = Some(pq);
}
}
(false, _) => return Err(Error::Reset(Reason::PROTOCOL_ERROR)), _ => {}
}
let mut req = Request::new((content_length, protocol));
*req.version_mut() = Version::HTTP_2;
*req.headers_mut() = headers;
*req.method_mut() = method;
if let Ok(uri) = uri::Uri::from_parts(uri_parts) {
*req.uri_mut() = uri;
}
let stream = Stream::new(
self.send_stream_initial_window,
self.recv_stream_initial_window,
content_length,
end_stream,
);
self.stream_map.insert(id, stream);
Ok(Some((req, id)))
}
/// Try to advance the last seen peer stream id to `id`.
///
/// `Ok(None)` means `LastStreamId::try_set` declined the id without an error. Callers in this
/// file then ignore the stream. Errors are converted into [`Error`].
pub(crate) fn try_set_last_stream_id(&mut self, id: StreamId) -> Result<Option<()>, Error> {
    let outcome = self.last_stream_id.try_set(id);
    outcome.map_err(|err| err.into())
}
pub(crate) fn recv_data(&mut self, data: Data) -> Result<(), Error> {
let id = data.stream_id();
self.check_not_idle(id)?;
let flow_len = data.flow_controlled_len() as u32;
let flow_len = RecvWindow::new(flow_len);
self.recv_window_dec(flow_len)?;
let stream = self.stream_map.get_mut(&id).ok_or_else(|| {
self.queue.connection_window_update(flow_len);
Error::Reset(Reason::STREAM_CLOSED)
})?;
let end_stream = data.is_end_stream();
let data = data.into_payload();
let (conn_window, stream_window, want_remove) =
match stream.try_recv_data(&mut self.frame_buf, data, flow_len, end_stream)? {
RecvData::Queued(size) => {
(size, size, false)
}
RecvData::Discard(size) => {
let stream_window = if !end_stream { size } else { RecvWindow::ZERO };
(size, stream_window, end_stream)
}
RecvData::StreamReset(size) => (size, RecvWindow::ZERO, true),
};
self.queue.connection_window_update(conn_window);
self.queue.stream_window_update(id, stream_window);
if want_remove {
self.inline_remove_stream(id)?;
}
Ok(())
}
/// Handle a WINDOW_UPDATE frame.
///
/// - A zero increment on the connection (stream 0) is a connection error.
/// - A zero increment on a stream marks that stream (if still tracked) as reset with
///   `StreamError::WindowUpdateZeroIncrement`.
/// - A positive connection increment grows the connection send window. If the window had no
///   credit before, all streams are notified so that ones waiting on connection credit can
///   resume.
/// - A positive stream increment is forwarded to that stream if it is still tracked.
///
/// # Errors
///
/// - `Error::GoAway(PROTOCOL_ERROR)` for an idle stream id or a zero connection increment.
/// - `Error::GoAway(FLOW_CONTROL_ERROR)` when the connection window cannot grow by the increment.
pub(crate) fn recv_window_update(&mut self, window: WindowUpdate) -> Result<(), Error> {
    let id = window.stream_id();
    self.check_not_idle(id)?;
    match (window.size_increment(), id) {
        (0, StreamId::ZERO) => return Err(Error::GoAway(Reason::PROTOCOL_ERROR)),
        (0, id) => {
            if let Some(state) = self.stream_map.get_mut(&id) {
                state.try_set_reset(StreamError::WindowUpdateZeroIncrement);
            }
        }
        (incr, StreamId::ZERO) => {
            // Use the same predicate the rest of this type uses for "connection credit
            // available" (`is_positive`). Any window at or below zero before the increment may
            // have left streams waiting, so it triggers a wake-up.
            let was_exhausted = !self.send_connection_window.is_positive();
            self.send_connection_window
                .try_inc(SendWindow::from_u32(incr))
                .map_err(|_| Error::GoAway(Reason::FLOW_CONTROL_ERROR))?;
            if was_exhausted {
                self.update_and_wake_send_streams(SendWindow::ZERO);
            }
        }
        (incr, id) => {
            let conn_window_positive = self.send_connection_window.is_positive();
            if let Some(stream) = self.stream_map.get_mut(&id) {
                stream.try_send_window_update(SendWindow::from_u32(incr), conn_window_positive);
            }
        }
    }
    Ok(())
}
/// Handle a PING frame.
///
/// A non-ACK ping stores its payload so `poll_encode` answers it. Only the most recent
/// unanswered payload is kept. An ACK returns the keepalive state to `Idle`.
pub(crate) fn recv_ping(&mut self, ping: Ping) {
    if !ping.is_ack {
        self.queue.pending_client_ping = Some(ping.payload);
        return;
    }
    self.queue.keepalive_ping = KeepalivePing::Idle;
}
/// Handle a RST_STREAM frame from the peer.
///
/// Every accepted reset counts against the reset rate limit. A tracked stream is marked as
/// peer-reset and then removed if possible. A reset for an untracked stream is otherwise ignored.
///
/// # Errors
///
/// - `Error::GoAway(PROTOCOL_ERROR)` for stream 0 or an idle stream.
/// - `Error::GoAway(ENHANCE_YOUR_CALM)` when the rate limit trips.
/// - Errors from removing the stream.
#[cold]
#[inline(never)]
pub(crate) fn recv_reset(&mut self, reset: Reset) -> Result<(), Error> {
    let id = reset.stream_id();
    if id.is_zero() {
        return Err(Error::GoAway(Reason::PROTOCOL_ERROR));
    }
    self.check_not_idle(id)?;
    self.try_tick_reset()?;

    match self.stream_map.get_mut(&id) {
        Some(stream) => {
            stream.try_set_peer_reset();
            self.inline_remove_stream(id)
        }
        None => Ok(()),
    }
}
/// Mark stream `id` as reset with `StreamError::InternalError`.
///
/// # Panics
///
/// If the stream for `id` is no longer tracked.
#[cold]
#[inline(never)]
pub(crate) fn internal_reset(&mut self, id: &StreamId) {
    let stream = self.stream_map.get_mut(id).expect(STREAM_MUST_EXIST);
    stream.try_set_reset(StreamError::InternalError);
}
#[cold]
#[inline(never)]
pub(crate) fn go_away(&mut self, err: Error) -> bool {
let Error::GoAway(reason) = err else {
unreachable!("Error::Reset MUST not be handled as GO_AWAY frame")
};
if let Some(last_stream_id) = self.last_stream_id.try_go_away() {
self.queue.push(Message::GoAway { last_stream_id, reason });
}
let fatal = reason != Reason::NO_ERROR;
if fatal {
self.queue.close();
}
fatal
}
/// Poll the next body frame of stream `id`.
///
/// The byte length of each yielded DATA frame accumulates in `pending_window`. Once it reaches
/// `recv_threshold`, the whole amount is credited back to the stream and queued as both a
/// connection-level and a stream-level window update, and `pending_window` is reset to zero.
///
/// # Panics
///
/// If the stream for `id` is no longer tracked.
pub(crate) fn poll_stream_frame(
    &mut self,
    id: &StreamId,
    pending_window: &mut RecvWindow,
    cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame, StreamError>>> {
    let stream = self.stream_map.get_mut(id).expect(STREAM_MUST_EXIST);
    match stream.poll_frame(&mut self.frame_buf, cx) {
        Poll::Ready(Some(Ok(frame))) => {
            let consumed = frame.data_ref().map(|bytes| bytes.len() as u32);
            if let Some(len) = consumed {
                *pending_window += len;
                if *pending_window >= self.recv_threshold {
                    let window = mem::replace(pending_window, RecvWindow::ZERO);
                    stream.recv_window_update(window);
                    self.queue.connection_window_update(window);
                    self.queue.stream_window_update(*id, window);
                }
            }
            Poll::Ready(Some(Ok(frame)))
        }
        other => other,
    }
}
/// Whether the receive side of stream `id` has reached its end.
///
/// # Panics
///
/// If the stream for `id` is no longer tracked.
pub(crate) fn is_recv_end_stream(&self, id: &StreamId) -> bool {
    let stream = self.stream_map.get(id).expect(STREAM_MUST_EXIST);
    stream.is_recv_end_stream()
}

/// Request a keepalive PING to be written by the next `poll_encode`.
///
/// # Errors
///
/// `io::ErrorKind::TimedOut` when the previous keepalive PING has not returned to `Idle`
/// (i.e. was not acknowledged yet).
pub(super) fn try_set_pending_ping(&mut self) -> io::Result<()> {
    let keepalive = &mut self.queue.keepalive_ping;
    keepalive.try_set_pending_ping()
}

/// Count one stream reset against the reset rate limit.
///
/// # Errors
///
/// `Error::GoAway(ENHANCE_YOUR_CALM)` when the reset counter trips.
fn try_tick_reset(&mut self) -> Result<(), Error> {
    if !self.reset_counter.tick() {
        return Ok(());
    }
    Err(Error::GoAway(Reason::ENHANCE_YOUR_CALM))
}
/// Validate a PRIORITY frame. Its content is otherwise ignored.
///
/// # Errors
///
/// Whatever `Priority::load` rejects, converted into [`Error`].
#[cold]
#[inline(never)]
pub(crate) fn recv_priority(&mut self, head: head::Head, frame: &BytesMut) -> Result<(), Error> {
    match Priority::load(head, frame) {
        Ok(_) => Ok(()),
        Err(err) => Err(err.into()),
    }
}
#[cold]
#[inline(never)]
pub(crate) fn recv_setting(&mut self, head: head::Head, frame: &BytesMut) -> Result<(), Error> {
let setting = Settings::load(head, frame)?;
if setting.is_ack() {
return Ok(());
}
if let Some(new_window) = setting.initial_window_size() {
let new_initial = SendWindow::new(new_window as i32);
let delta = new_initial - self.send_stream_initial_window;
self.send_stream_initial_window = new_initial;
if delta > SendWindow::ZERO {
for stream in self.stream_map.values() {
stream
.send_window_check(delta)
.map_err(|err| Error::GoAway(err.reason()))?;
}
}
if delta != SendWindow::ZERO {
self.update_and_wake_send_streams(delta);
}
}
if let Some(frame_size) = setting.max_frame_size() {
self.max_frame_size = SendWindow::new(frame_size as i32);
}
self.queue.pending_settings.try_update(setting)
}
/// Encode everything pending for the connection writer into `write_buf`.
///
/// Encoding order:
/// 1. A pending SETTINGS ACK, which first applies any header table size the peer announced to
///    `encoder`.
/// 2. Every queued [`Message`], in FIFO order.
/// 3. One aggregated connection-level WINDOW_UPDATE.
/// 4. A PING ACK for the last received client ping.
/// 5. Our own keepalive PING, if one is pending.
///
/// Returns `Poll::Ready(true)` when `write_buf` is non-empty, including bytes left over from
/// before this call. Returns `Poll::Ready(false)` when nothing is buffered and the queue was
/// closed. Otherwise registers `cx`'s waker on the queue and returns `Poll::Pending`.
pub(crate) fn poll_encode(
&mut self,
write_buf: &mut BytesMut,
encoder: &mut hpack::Encoder,
cx: &mut Context<'_>,
) -> Poll<bool> {
// The SETTINGS ACK (and its encoder table size update) goes out before any queued header
// block is encoded with `encoder`.
self.queue.pending_settings.encode(encoder, write_buf);
while let Some(msg) = self.queue.try_recv() {
match msg {
Message::Head(headers) => {
let frame_size = self.max_frame_size.as_frame_size();
// `cont` is the remainder of the header block that did not fit within `frame_size`;
// keep encoding it (presumably as CONTINUATION frames) until nothing is left.
let mut cont = headers.encode(encoder, &mut write_buf.limit(frame_size));
while let Some(c) = cont {
cont = c.encode(&mut write_buf.limit(frame_size));
}
}
// Same frame-size-bounded encoding as `Head`, for trailer header blocks.
Message::Trailer(headers) => {
let frame_size = self.max_frame_size.as_frame_size();
let mut cont = headers.encode(encoder, &mut write_buf.limit(frame_size));
while let Some(c) = cont {
cont = c.encode(&mut write_buf.limit(frame_size));
}
}
Message::Data(mut data) => data.encode_chunk(write_buf),
Message::Reset { stream_id, reason } => Reset::new(stream_id, reason).encode(write_buf),
Message::WindowUpdate { stream_id, size } => {
WindowUpdate::new(stream_id, size.value()).encode(write_buf)
}
Message::GoAway { last_stream_id, reason } => {
GoAway::new(last_stream_id, reason).encode(write_buf);
}
Message::Settings(settings) => settings.encode(write_buf),
}
}
// Connection window released since the last call: restore it on our side and advertise it
// with a single WINDOW_UPDATE on stream 0.
let pending = mem::replace(&mut self.queue.pending_conn_window, RecvWindow::ZERO);
if pending != RecvWindow::ZERO {
self.recv_connection_window += pending;
WindowUpdate::new(StreamId::zero(), pending.value()).encode(write_buf);
}
// Answer the most recent client PING with an ACK carrying the same payload.
if let Some(payload) = self.queue.pending_client_ping.take() {
Ping::new(payload, true).encode(write_buf);
}
self.queue.keepalive_ping.encode(write_buf);
if !write_buf.is_empty() {
Poll::Ready(true)
} else if self.queue.is_closed() {
Poll::Ready(false)
} else {
self.queue.register(cx);
Poll::Pending
}
}
/// Queue a response header block for encoding by `poll_encode`.
pub(crate) fn send_headers(&mut self, headers: Headers<ResponsePseudo>) {
    let message = Message::Head(headers);
    self.queue.push(message);
}
/// Queue `data` as DATA frame(s) for stream `id`, bounded by flow control and max frame size.
///
/// Each loop iteration asks the stream for `min(data.len(), max_frame_size)` bytes of send
/// window, drawing on the connection window as well. The granted amount is split off the front
/// of `data` and queued. Bytes not yet granted stay in `data`.
///
/// Returns:
/// - `Poll::Ready(Some(()))` when sending for this stream is finished. Either the last chunk was
///   queued with END_STREAM, or `poll_send_window` yielded something other than `Some(Ok(_))`
///   (presumably the stream was closed or reset).
/// - `Poll::Ready(None)` when all of `data` was queued without END_STREAM, so more may follow.
/// - `Poll::Pending` while `poll_send_window` is pending.
///
/// An empty `data` with `end_stream` queues an empty END_STREAM DATA frame and returns
/// `Ready(Some(()))`. An empty `data` without `end_stream` is only logged and returns
/// `Ready(None)`.
///
/// # Panics
///
/// If `data` is non-empty and the stream for `id` is no longer tracked.
pub(crate) fn poll_send_data(
&mut self,
id: StreamId,
data: &mut Bytes,
end_stream: bool,
cx: &mut Context<'_>,
) -> Poll<Option<()>> {
if data.is_empty() {
let opt = if !end_stream {
tracing::warn!("Empty Data frame is not allowed unless it's the last frame of stream");
None
} else {
let payload = mem::take(data);
self.queue.push_data(id, payload, end_stream);
Some(())
};
return Poll::Ready(opt);
}
let stream = self.stream_map.get_mut(&id).expect(STREAM_MUST_EXIST);
loop {
let len = data.len();
// Never request more than one frame's worth of window per chunk.
let req = SendWindow::from_usize_saturating(len).min(self.max_frame_size);
let Some(Ok(aval)) = ready!(stream.poll_send_window(req, &mut self.send_connection_window, cx)) else {
return Poll::Ready(Some(()));
};
let aval = aval.as_frame_size();
let all_consumed = aval == len;
// Take the whole buffer when fully granted, otherwise split off only the granted prefix.
let payload = if all_consumed {
mem::take(data)
} else {
data.split_to(aval)
};
// END_STREAM may only go on the chunk that carries the final bytes.
let end_stream = all_consumed && end_stream;
self.queue.push_data(id, payload, end_stream);
if end_stream {
return Poll::Ready(Some(()));
} else if all_consumed {
return Poll::Ready(None);
}
}
}
/// Queue the trailer header block for stream `id`.
pub(crate) fn send_trailers(&mut self, id: StreamId, trailers: HeaderMap) {
    let queue = &mut self.queue;
    queue.push_trailers(id, trailers);
}

/// End the send side of stream `id`.
///
/// Sets END_STREAM on its last still-queued DATA frame, or queues an empty END_STREAM frame.
pub(crate) fn send_end_stream(&mut self, id: StreamId) {
    let queue = &mut self.queue;
    queue.push_end_stream(id);
}

/// Close the writer queue. `poll_encode` reports `Ready(false)` once nothing is left to write.
pub(crate) fn close_write_queue(&mut self) {
    self.queue.closed = true;
}
/// Mark every tracked stream as reset according to how the connection ended.
///
/// `res` being `Ok` maps to `StreamError::GoAway`; an I/O error maps to `StreamError::Io`.
pub(crate) fn reset_all_stream(&mut self, res: &io::Result<()>) {
    let stream_err = match res {
        Ok(()) => StreamError::GoAway,
        Err(_) => StreamError::Io,
    };
    self.stream_map
        .values_mut()
        .for_each(|stream| stream.try_set_reset(stream_err));
}
/// Queue our SETTINGS frame and a connection-level window update.
///
/// The window update raises the connection receive window by however much the configured
/// stream initial window exceeds `RecvWindow::default()`.
pub(crate) fn init(&mut self, settings: Settings) {
    let extra_conn_window = self.recv_stream_initial_window.saturating_sub(RecvWindow::default());
    self.queue.push(Message::Settings(settings));
    self.queue.connection_window_update(extra_conn_window);
}
}
/// Outgoing frames and control-frame state consumed by `FlowControl::poll_encode`.
struct WriterQueue {
// Frames to encode, in FIFO order.
messages: VecDeque<Message>,
// Set once the writer should stop after draining what is buffered.
closed: bool,
// Peer SETTINGS still owed an ACK.
pending_settings: RemoteSettings,
// Connection receive window released but not yet advertised via WINDOW_UPDATE.
pending_conn_window: RecvWindow,
// State of our own keepalive PING.
keepalive_ping: KeepalivePing,
// Payload of the most recent client PING still awaiting an ACK.
pending_client_ping: Option<[u8; 8]>,
// Waker registered by the last `poll_encode` that returned `Pending`.
waker: Option<Waker>,
}
impl WriterQueue {
    /// An open, empty queue with no pending control frames and no registered waker.
    fn new() -> Self {
        Self {
            messages: VecDeque::default(),
            closed: false,
            pending_settings: Default::default(),
            pending_conn_window: RecvWindow::ZERO,
            keepalive_ping: KeepalivePing::Idle,
            pending_client_ping: None,
            waker: None,
        }
    }

    /// Append `msg` behind everything already queued.
    fn push(&mut self, msg: Message) {
        self.messages.push_back(msg);
    }

    /// Accumulate connection receive window to be advertised by the next `poll_encode`.
    fn connection_window_update(&mut self, size: RecvWindow) {
        self.pending_conn_window += size;
    }

    /// Queue a stream-level WINDOW_UPDATE. Zero-sized updates are skipped.
    fn stream_window_update(&mut self, id: StreamId, size: RecvWindow) {
        if size == RecvWindow::ZERO {
            return;
        }
        self.push(Message::WindowUpdate { stream_id: id, size });
    }

    /// Queue a DATA frame for stream `id`, with END_STREAM set when `end_stream` is true.
    fn push_data(&mut self, id: StreamId, payload: Bytes, end_stream: bool) {
        let mut frame = Data::new(id, payload);
        frame.set_end_stream(end_stream);
        self.push(Message::Data(frame));
    }

    /// Queue the trailer header block for stream `id`.
    fn push_trailers(&mut self, id: StreamId, trailers: HeaderMap) {
        self.push(Message::Trailer(Headers::trailers(id, trailers)));
    }

    /// End the send side of `stream_id`.
    ///
    /// Sets END_STREAM on the most recently queued DATA frame of that stream when one is still
    /// waiting in the queue. Otherwise queues an empty DATA frame with END_STREAM.
    fn push_end_stream(&mut self, stream_id: StreamId) {
        let last_data = self.messages.iter_mut().rev().find_map(|msg| match msg {
            Message::Data(data) if data.stream_id() == stream_id => Some(data),
            _ => None,
        });
        match last_data {
            Some(data) => data.set_end_stream(true),
            None => self.push_data(stream_id, Bytes::new(), true),
        }
    }

    /// Mark the queue as closed.
    fn close(&mut self) {
        self.closed = true;
    }

    /// Pop the oldest queued message.
    fn try_recv(&mut self) -> Option<Message> {
        self.messages.pop_front()
    }

    /// Whether `close` has been called.
    fn is_closed(&self) -> bool {
        self.closed
    }

    /// Store `cx`'s waker for `wake`. An existing waker is kept when it would wake the same task.
    fn register(&mut self, cx: &mut Context<'_>) {
        let current = cx.waker();
        let up_to_date = matches!(&self.waker, Some(stored) if stored.will_wake(current));
        if !up_to_date {
            self.waker = Some(current.clone());
        }
    }

    /// Wake the registered task, consuming the stored waker.
    fn wake(&mut self) {
        if let Some(writer) = self.waker.take() {
            writer.wake();
        }
    }
}
#[derive(Default)]
struct RemoteSettings {
header_table_size: Option<Option<u32>>,
}
impl RemoteSettings {
#[cold]
#[inline(never)]
fn try_update(&mut self, settings: Settings) -> Result<(), Error> {
if self.header_table_size.is_some() {
return Err(Error::GoAway(Reason::ENHANCE_YOUR_CALM));
}
self.header_table_size = Some(settings.header_table_size());
Ok(())
}
fn encode(&mut self, encoder: &mut hpack::Encoder, buf: &mut BytesMut) {
if let Some(header_table_size) = self.header_table_size.take() {
if let Some(size) = header_table_size {
encoder.update_max_size(size as usize);
}
Settings::ack().encode(buf);
}
}
}
/// State of our keepalive PING.
///
/// The cycle is `Idle` → `Pending` (`try_set_pending_ping`) → `InFlight` (`encode`) → back to
/// `Idle` when a PING ACK arrives in `FlowControl::recv_ping`.
enum KeepalivePing {
    Idle,
    Pending,
    InFlight,
}

impl KeepalivePing {
    /// Write a zero-payload PING when one is pending and mark it as in flight.
    fn encode(&mut self, write_buf: &mut BytesMut) {
        match self {
            KeepalivePing::Pending => {
                Ping::new([0u8; 8], false).encode(write_buf);
                *self = KeepalivePing::InFlight;
            }
            KeepalivePing::Idle | KeepalivePing::InFlight => {}
        }
    }

    /// Request a new keepalive PING.
    ///
    /// # Errors
    ///
    /// `io::ErrorKind::TimedOut` ("h2 ping timeout") unless the state is `Idle`.
    fn try_set_pending_ping(&mut self) -> io::Result<()> {
        match self {
            KeepalivePing::Idle => {
                *self = KeepalivePing::Pending;
                Ok(())
            }
            KeepalivePing::Pending | KeepalivePing::InFlight => {
                Err(io::Error::new(io::ErrorKind::TimedOut, "h2 ping timeout"))
            }
        }
    }
}
/// Failure outcome of frame handling.
///
/// `Reset` asks for a single stream to be reset (RST_STREAM); `GoAway` asks for the whole
/// connection to be shut down (GOAWAY). Each carries the reason to send.
pub(crate) enum Error {
    Reset(Reason),
    GoAway(Reason),
}

impl From<ProtoError> for Error {
    fn from(e: ProtoError) -> Self {
        // Pick the variant constructor first, then apply it to the reason.
        let variant = if e.is_go_away() { Self::GoAway } else { Self::Reset };
        variant(e.reason())
    }
}

impl From<StreamError> for Error {
    /// Stream-scoped failures always become a stream reset.
    fn from(err: StreamError) -> Self {
        let reason = err.reason();
        Error::Reset(reason)
    }
}
/// An outgoing frame queued in [`WriterQueue`] and encoded by `FlowControl::poll_encode`.
enum Message {
// Response header block.
Head(Headers<ResponsePseudo>),
// DATA frame (possibly with END_STREAM).
Data(Data),
// Trailer header block.
Trailer(Headers<()>),
// RST_STREAM for `stream_id`.
Reset { stream_id: StreamId, reason: Reason },
// Stream-level WINDOW_UPDATE; connection-level updates are aggregated separately.
WindowUpdate { stream_id: StreamId, size: RecvWindow },
// GOAWAY frame.
GoAway { last_stream_id: StreamId, reason: Reason },
// Our own SETTINGS frame.
Settings(Settings),
}