use core::{
cmp, mem,
task::{Context, Poll, Waker},
};
use std::io;
use crate::{
body::SizeHint,
bytes::Bytes,
h2::{
dispatcher::{Frame, FrameBuffer},
util::Deque,
},
http::HeaderMap,
};
use super::{
error::Error,
frame::{reason::Reason, settings},
size::BodySize,
};
/// Per-stream state: the inbound half (`recv`), the outbound half (`send`) and any
/// reset that has been recorded for the stream.
pub(crate) struct Stream {
    /// Inbound half: queued frames, receive window and receive-side state.
    pub(crate) recv: Recv,
    /// Outbound half: send window, frame size limit and send-side state.
    pub(crate) send: Send,
    // Reset bookkeeping; a local reset reason recorded here is handed back by
    // `try_remove` as `Remove::Reset` once both halves are closed.
    pending_reset: PendingReset,
}
impl Stream {
#[allow(dead_code)]
const HEAD_METHOD: u8 = 1 << 3;
pub(crate) fn new(send_window: i64, send_frame_size: usize, content_length: SizeHint, end_stream: bool) -> Self {
let (window, state) = if end_stream {
(0, State::Eof)
} else {
(settings::DEFAULT_INITIAL_WINDOW_SIZE as usize, State::Open)
};
Self {
recv: Recv {
queue: Deque::new(),
waker: None,
window,
state,
content_length,
},
send: Send {
window: send_window,
frame_size: send_frame_size,
waker: None,
state: State::Open,
},
pending_reset: PendingReset::None,
}
}
pub(crate) fn try_recv_data(
&mut self,
buffer: &mut FrameBuffer,
data: Bytes,
end_stream: bool,
) -> Result<RecvData, Error> {
self.recvable()?;
let len = data.len();
let recv = match self.data_check(len, end_stream) {
Ok(_) => {
if self.recv.state.is_open() {
self.recv.push_frame(buffer, Frame::Data(data), end_stream);
RecvData::Queued
} else {
self.recv.window += len;
RecvData::Discard(len)
}
}
Err(err) => {
self.set_reset(err);
RecvData::StreamReset(len)
}
};
Ok(recv)
}
pub(crate) fn try_recv_trailers(
&mut self,
buffer: &mut FrameBuffer,
trailers: HeaderMap,
end_stream: bool,
) -> Result<RecvData, Error> {
self.recvable()?;
let recv = match self.trailers_check(end_stream) {
Ok(_) => {
if self.recv.state.is_open() {
self.recv.push_frame(buffer, Frame::Trailers(trailers), true);
RecvData::Queued
} else {
RecvData::Discard(0)
}
}
Err(err) => {
self.set_reset(err);
RecvData::StreamReset(0)
}
};
Ok(recv)
}
pub(crate) fn poll_send_window(
&mut self,
len: usize,
window: usize,
cx: &mut Context<'_>,
) -> Poll<Option<Result<usize, StreamError>>> {
self.send.poll_send_window(len, window, cx)
}
pub(crate) fn maybe_close_recv(&mut self, buffer: &mut FrameBuffer) -> RecvClose {
self.recv.set_close();
let mut window = 0;
while let Some(frame) = self.recv.queue.pop_front(buffer) {
if let Frame::Data(bytes) = frame {
window += bytes.len();
}
}
self.recv.window += window;
if self.recv.state.is_close() {
RecvClose::Close(window)
} else {
RecvClose::Cancel(window)
}
}
pub(crate) fn is_send_close(&self) -> bool {
self.send.state.is_close()
}
pub(crate) fn try_remove(&mut self) -> Option<Remove> {
(self.is_send_close() && self.recv.state.is_close()).then(|| match self.pending_reset.take() {
Some(reason) => Remove::Reset(reason),
None => Remove::Graceful,
})
}
pub(crate) fn set_reset(&mut self, err: StreamError) {
self.pending_reset.try_set_local(err.reason());
self.recv.try_set_err(err);
self.send.try_set_err(err);
}
pub(crate) fn try_set_peer_reset(&mut self) {
self.pending_reset.try_set_peer();
self.recv.try_set_err(StreamError::PeerReset);
self.send.try_set_err(StreamError::PeerReset);
}
pub(crate) fn take_recv_err(&self) -> Option<io::Error> {
self.recv.state.take_error()
}
fn recvable(&self) -> Result<(), Error> {
match (&self.pending_reset, &self.recv.state) {
(PendingReset::Peer, _) => Err(Error::GoAway(Reason::STREAM_CLOSED)),
(_, State::Eof) => Err(Error::GoAway(Reason::PROTOCOL_ERROR)),
(_, _) => Ok(()),
}
}
fn data_check(&mut self, len: usize, end_stream: bool) -> Result<(), StreamError> {
if len == 0 && !end_stream {
return Err(StreamError::EmptyDataNoEndStream);
}
self.recv
.content_length
.dec(len)
.map_err(|_| StreamError::ContentLengthOverflow)?;
if end_stream {
self.ensure_zero()?;
}
self.try_window_dec(len)
}
fn trailers_check(&self, end_stream: bool) -> Result<(), StreamError> {
if !end_stream {
return Err(StreamError::TrailersNoEndStream);
}
self.ensure_zero()
}
fn try_window_dec(&mut self, len: usize) -> Result<(), StreamError> {
match self.recv.window.checked_sub(len) {
Some(window) => {
self.recv.window = window;
Ok(())
}
None => Err(StreamError::FlowControlOverflow),
}
}
fn ensure_zero(&self) -> Result<(), StreamError> {
self.recv
.content_length
.ensure_zero()
.map_err(|_| StreamError::ContentLengthUnderflow)
}
}
/// Outcome of `Stream::maybe_close_recv`.
///
/// The payload is the number of queued DATA bytes that were dropped and credited back to
/// the stream receive window.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RecvClose {
    /// The receive side was still `Open` and has moved to `Cancel`.
    Cancel(usize),
    /// The receive side ended up in `Close`.
    Close(usize),
}
/// Outcome of `Stream::try_remove` once both halves of the stream are closed.
pub(crate) enum Remove {
    /// No local reset was recorded for the stream.
    Graceful,
    /// A local reset was recorded; carries its reason code.
    Reset(Reason),
}
/// Result of feeding an inbound DATA or trailer frame to a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RecvData {
    /// The frame was queued for the reader.
    Queued,
    /// The frame was valid but dropped because the receive side is no longer open.
    /// Carries the number of payload bytes dropped (0 for trailers).
    Discard(usize),
    /// Validation failed and a local reset was recorded. Carries the payload length
    /// (0 for trailers).
    StreamReset(usize),
}
pub(crate) struct Recv {
pub(crate) queue: Deque,
pub(crate) waker: Option<Waker>,
pub(crate) window: usize,
state: State,
content_length: SizeHint,
}
impl Recv {
fn try_set_err(&mut self, err: StreamError) {
if self.state.is_open() {
self.state = State::Error(err);
self.wake();
}
}
fn push_frame(&mut self, buffer: &mut FrameBuffer, frame: Frame, end_stream: bool) {
self.queue.push_back(buffer, frame);
if end_stream {
self.state = State::Eof;
}
self.wake();
}
fn set_close(&mut self) {
self.state = match self.state {
State::Open => State::Cancel,
_ => State::Close,
}
}
pub(crate) fn is_eof(&self) -> bool {
matches!(self.state, State::Eof)
}
pub(crate) fn set_close_2(&mut self) {
match self.state {
State::Error(_) => {}
_ => self.state = State::Close,
}
}
fn wake(&mut self) {
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}
pub(crate) struct Send {
pub(crate) window: i64,
pub(crate) frame_size: usize,
pub(crate) waker: Option<Waker>,
state: State,
}
impl Send {
fn poll_send_window(
&mut self,
len: usize,
window: usize,
cx: &mut Context<'_>,
) -> Poll<Option<Result<usize, StreamError>>> {
let opt = match self.state {
State::Error(err) => Some(Err(err)),
State::Close => None,
_ => {
if len > 0 && (window == 0 || self.window <= 0) {
self.waker = Some(cx.waker().clone());
return Poll::Pending;
}
let len = cmp::min(len, self.frame_size);
let aval = cmp::min(self.window as usize, window);
let aval = cmp::min(aval, len);
self.window -= aval as i64;
Some(Ok(aval))
}
};
Poll::Ready(opt)
}
fn try_set_err(&mut self, err: StreamError) {
if self.state.is_open() {
self.state = State::Error(err);
self.wake();
}
}
pub(crate) fn set_close(&mut self) {
self.state = State::Close;
}
pub(crate) fn wake(&mut self) {
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}
/// Reset bookkeeping for a stream; the first recorded reset wins.
enum PendingReset {
    /// No reset recorded.
    None,
    /// The peer reset the stream.
    Peer,
    /// A local reset with the given reason, handed out by `take`.
    Local(Reason),
}

impl PendingReset {
    /// Records a peer reset unless a reset is already recorded.
    fn try_set_peer(&mut self) {
        if let PendingReset::None = *self {
            *self = PendingReset::Peer;
        }
    }

    /// Records a local reset with `reason` unless a reset is already recorded.
    fn try_set_local(&mut self, reason: Reason) {
        if let PendingReset::None = *self {
            *self = PendingReset::Local(reason);
        }
    }

    /// Resets the state to `None`. Returns the reason only if a local reset was recorded.
    fn take(&mut self) -> Option<Reason> {
        let previous = mem::replace(self, PendingReset::None);
        if let PendingReset::Local(reason) = previous {
            Some(reason)
        } else {
            None
        }
    }
}
/// Lifecycle of one half (receive or send) of a stream.
enum State {
    /// Frames are still flowing.
    Open,
    /// The receive side was closed locally while still `Open` (set by `Recv::set_close`).
    Cancel,
    /// The final inbound frame (END_STREAM) has been queued.
    Eof,
    /// A stream error was recorded while `Open`.
    Error(StreamError),
    /// Fully closed.
    Close,
}

impl State {
    fn is_open(&self) -> bool {
        matches!(*self, Self::Open)
    }

    fn is_close(&self) -> bool {
        matches!(*self, Self::Close)
    }

    /// Converts a recorded error into an `io::Error`. The state itself is left unchanged.
    fn take_error(&self) -> Option<io::Error> {
        if let Self::Error(err) = *self {
            Some(io::Error::from(err))
        } else {
            None
        }
    }
}
/// Stream-scoped failure.
///
/// `Stream::set_reset` records a local reset using `reason()`. Connection-wide failures
/// are reported separately as `Error::GoAway`.
#[derive(Debug, Clone, Copy)]
pub(crate) enum StreamError {
    /// A DATA frame with an empty payload that did not carry END_STREAM.
    EmptyDataNoEndStream,
    /// A trailer HEADERS frame without END_STREAM.
    TrailersNoEndStream,
    /// More DATA was received than the expected content length allows.
    ContentLengthOverflow,
    /// END_STREAM arrived before the expected content length was reached.
    ContentLengthUnderflow,
    /// A DATA frame exceeded the remaining stream receive window.
    FlowControlOverflow,
    /// The peer reset the stream.
    PeerReset,
    /// A WINDOW_UPDATE with a zero increment.
    WindowUpdateZeroIncrement,
    /// A WINDOW_UPDATE that overflowed the window.
    WindowUpdateOverflow,
    /// An internal failure.
    InternalError,
}

impl StreamError {
    /// HTTP/2 error code associated with this error.
    ///
    /// The match is deliberately exhaustive (no `_` arm), so adding a variant forces an
    /// explicit choice of reset code instead of silently defaulting to PROTOCOL_ERROR.
    pub(crate) fn reason(&self) -> Reason {
        match self {
            Self::FlowControlOverflow | Self::WindowUpdateOverflow => Reason::FLOW_CONTROL_ERROR,
            Self::PeerReset => Reason::NO_ERROR,
            Self::InternalError => Reason::INTERNAL_ERROR,
            Self::EmptyDataNoEndStream
            | Self::TrailersNoEndStream
            | Self::ContentLengthOverflow
            | Self::ContentLengthUnderflow
            | Self::WindowUpdateZeroIncrement => Reason::PROTOCOL_ERROR,
        }
    }
}
impl From<StreamError> for io::Error {
fn from(err: StreamError) -> Self {
let msg = match err {
StreamError::EmptyDataNoEndStream => "empty DATA without END_STREAM",
StreamError::TrailersNoEndStream => "trailer HEADERS without END_STREAM",
StreamError::ContentLengthOverflow => "content-length exceeded",
StreamError::ContentLengthUnderflow => "content-length underflow at END_STREAM",
StreamError::FlowControlOverflow => "stream flow control overflow",
StreamError::PeerReset => "h2 stream reset by peer",
StreamError::WindowUpdateZeroIncrement => "WINDOW_UPDATE with zero increment",
StreamError::WindowUpdateOverflow => "WINDOW_UPDATE caused window overflow",
StreamError::InternalError => "ineternal error",
};
io::Error::new(io::ErrorKind::InvalidData, msg)
}
}