use std::{
    cell::Cell, cmp, cmp::Ordering, fmt, future::poll_fn, mem, ops, rc::Rc, task::Context,
    task::Poll,
};
use ntex_bytes::Bytes;
use ntex_http::{header::CONTENT_LENGTH, HeaderMap, Method, StatusCode};
use ntex_util::task::LocalWaker;
use crate::error::{OperationError, StreamError};
use crate::frame::{
    Data, Headers, PseudoHeaders, Reason, Reset, StreamId, WindowSize, WindowUpdate,
};
use crate::{connection::Connection, frame, message::Message, window::Window};
/// Owning handle to an HTTP/2 stream.
///
/// Dereferences to [`StreamRef`]; dropping it resets the stream with
/// `Reason::CANCEL` unless both halves are already closed.
pub struct Stream(StreamRef);
/// Handle to a number of received payload bytes that are still counted
/// against the stream's receive capacity.
///
/// Creating one (`Capacity::new`) reserves the bytes on the stream; consuming
/// or dropping it releases them, which may make the stream send a
/// WINDOW_UPDATE.
#[derive(Debug)]
pub struct Capacity {
    // Bytes still held by this handle; interior mutability lets `consume`
    // and the `Add` impls work through shared references.
    size: Cell<u32>,
    // Stream the bytes are accounted against.
    stream: Rc<StreamState>,
}
impl Capacity {
fn new(size: u32, stream: &Rc<StreamState>) -> Self {
stream.add_capacity(size);
Self {
size: Cell::new(size),
stream: stream.clone(),
}
}
#[inline]
pub fn size(&self) -> usize {
self.size.get() as usize
}
pub fn consume(&self, sz: u32) {
let size = self.size.get();
if let Some(sz) = size.checked_sub(sz) {
log::trace!(
"{:?} capacity consumed from {} to {}",
self.stream.id,
size,
sz
);
self.size.set(sz);
self.stream.consume_capacity(size - sz);
} else {
panic!("Capacity overflow");
}
}
}
impl ops::Add for Capacity {
type Output = Self;
fn add(self, other: Self) -> Self {
if Rc::ptr_eq(&self.stream, &other.stream) {
let size = Cell::new(self.size.get() + other.size.get());
self.size.set(0);
other.size.set(0);
Self {
size,
stream: self.stream.clone(),
}
} else {
panic!("Cannot add capacity from different streams");
}
}
}
impl ops::AddAssign for Capacity {
    /// Moves every byte held by `other` into `self`.
    ///
    /// # Panics
    ///
    /// Panics if the two handles belong to different streams.
    fn add_assign(&mut self, other: Self) {
        if !Rc::ptr_eq(&self.stream, &other.stream) {
            panic!("Cannot add capacity from different streams");
        }
        *self.size.get_mut() += other.size.get();
        // `other` is dropped at the end of this call; with zero bytes left its
        // `Drop` releases nothing.
        other.size.set(0);
    }
}
impl Drop for Capacity {
    /// Returns whatever bytes are still held to the stream.
    fn drop(&mut self) {
        let remaining = self.size.get();
        if remaining != 0 {
            self.stream.consume_capacity(remaining);
        }
    }
}
/// Expected length of the inbound payload.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ContentLength {
    /// No `content-length` has been seen; payload size is not checked.
    Omitted,
    /// Payload must be empty; a received `content-length` header is ignored.
    Head,
    /// Number of payload bytes still expected, taken from `content-length`
    /// and reduced as DATA frames arrive.
    Remaining(u64),
}
/// Reference-counted handle to a stream's shared state.
///
/// Clones share the same state; equality compares the identity of that
/// shared state rather than its contents.
#[derive(Clone, Debug)]
pub struct StreamRef(pub(crate) Rc<StreamState>);
bitflags::bitflags! {
    /// Per-stream boolean flags.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct StreamFlags: u8 {
        /// Set when the stream was created with `remote = true`
        /// (presumably peer-initiated).
        const REMOTE = 0b0000_0001;
        /// Set once the stream has been reset or has failed.
        const FAILED = 0b0000_0010;
    }
}
/// Mutable state shared by all handles to one stream.
pub(crate) struct StreamState {
    /// Stream identifier.
    id: StreamId,
    /// `REMOTE` / `FAILED` flags.
    flags: Cell<StreamFlags>,
    /// Expected remaining inbound payload length.
    content_length: Cell<ContentLength>,
    /// State of the receive half.
    recv: Cell<HalfState>,
    /// Receive flow-control window.
    recv_window: Cell<Window>,
    /// Received bytes currently held by outstanding `Capacity` handles.
    recv_size: Cell<u32>,
    /// State of the send half.
    send: Cell<HalfState>,
    /// Send flow-control window.
    send_window: Cell<Window>,
    /// Waker for tasks blocked in `poll_send_capacity`.
    send_cap: LocalWaker,
    /// Waker for tasks blocked in `poll_send_reset`.
    send_reset: LocalWaker,
    /// Connection that owns this stream.
    con: Connection,
    /// Most recently recorded stream error, if any.
    error: Cell<Option<OperationError>>,
}
/// State of one direction (send or receive) of a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum HalfState {
    /// Nothing has been sent/received in this direction yet.
    Idle,
    /// Headers were exchanged without END_STREAM; payload may follow.
    Payload,
    /// This direction is closed, optionally with the reset reason.
    Closed(Option<Reason>),
}
impl HalfState {
    /// Returns `true` once this half is `Closed`, with or without a reason.
    pub(crate) fn is_closed(&self) -> bool {
        !matches!(self, HalfState::Idle | HalfState::Payload)
    }
}
impl StreamState {
    /// Marks the send half as carrying payload.
    fn state_send_payload(&self) {
        self.send.set(HalfState::Payload);
    }

    /// Closes the send half with an optional reason.
    ///
    /// Wakes tasks waiting for send capacity so they can observe the closed
    /// state, then re-evaluates whether the stream can be dropped.
    fn state_send_close(&self, reason: Option<Reason>) {
        log::trace!("{:?} send side is closed with reason {:?}", self.id, reason);
        self.send.set(HalfState::Closed(reason));
        self.send_cap.wake();
        self.review_state();
    }

    /// Marks the receive half as carrying payload.
    fn state_recv_payload(&self) {
        self.recv.set(HalfState::Payload);
    }

    /// Closes the receive half and re-evaluates whether the stream can be dropped.
    fn state_recv_close(&self, reason: Option<Reason>) {
        log::trace!("{:?} receive side is closed", self.id);
        self.recv.set(HalfState::Closed(reason));
        self.review_state();
    }

    /// Sets the `FAILED` flag and wakes both send-capacity and send-reset waiters.
    fn set_failed(&self) {
        let mut flags = self.flags.get();
        flags.insert(StreamFlags::FAILED);
        self.flags.set(flags);
        self.send_cap.wake();
        self.send_reset.wake();
    }

    /// Local reset.
    ///
    /// Marks the stream failed and closes both halves; the send half records
    /// `reason`. When a reason is given it is also stored as
    /// `OperationError::LocalReset`. This does not itself encode RST_STREAM.
    fn reset_stream(&self, reason: Option<Reason>) {
        self.set_failed();
        self.recv.set(HalfState::Closed(None));
        self.send.set(HalfState::Closed(reason));
        if let Some(reason) = reason {
            self.error.set(Some(OperationError::LocalReset(reason)));
        }
        self.review_state();
    }

    /// Remote reset.
    ///
    /// Marks the stream failed, closes both halves (the receive half records
    /// `reason`) and stores `OperationError::RemoteReset`.
    fn remote_reset_stream(&self, reason: Reason) {
        self.set_failed();
        self.recv.set(HalfState::Closed(Some(reason)));
        self.send.set(HalfState::Closed(None));
        self.error.set(Some(OperationError::RemoteReset(reason)));
        self.review_state();
    }

    /// Marks the stream failed with `err`, closing both halves without a reason.
    fn failed(&self, err: OperationError) {
        self.set_failed();
        self.recv.set(HalfState::Closed(None));
        self.send.set(HalfState::Closed(None));
        self.error.set(Some(err));
        self.review_state();
    }

    /// Returns `Err` with a clone of the recorded error, if there is one.
    fn check_error(&self) -> Result<(), OperationError> {
        // `Cell::take` moves the error out; put a copy back so later calls
        // keep observing it.
        if let Some(err) = self.error.take() {
            self.error.set(Some(err.clone()));
            Err(err)
        } else {
            Ok(())
        }
    }

    /// Reacts to a half-state change.
    ///
    /// Once the receive half is closed, wakes `poll_send_reset` waiters. If
    /// the send half is closed as well, the stream is removed from the
    /// connection.
    fn review_state(&self) {
        if self.recv.get().is_closed() {
            self.send_reset.wake();
            if let HalfState::Closed(reason) = self.send.get() {
                if reason.is_some() {
                    log::trace!("{:?} is closed with local reset, dropping stream", self.id);
                } else {
                    log::trace!("{:?} both sides are closed, dropping stream", self.id);
                }
                self.con.drop_stream(self.id);
            }
        }
    }

    /// Accounts `size` received bytes as held capacity.
    ///
    /// Increases `recv_size`, shrinks the receive window by `size`, and
    /// forwards the same amount to the connection via `con.add_capacity`.
    fn add_capacity(&self, size: u32) {
        let cap = self.recv_size.get();
        self.recv_size.set(cap + size);
        self.recv_window.set(self.recv_window.get().dec(size));
        log::trace!(
            "{:?} capacity incresed from {} to {}",
            self.id,
            cap,
            cap + size
        );
        self.con.add_capacity(size);
    }

    /// Releases `size` bytes of held capacity.
    ///
    /// If `Window::update` (given the bytes still held and the configured
    /// window size and threshold) yields an increment, the updated window is
    /// stored and a WINDOW_UPDATE for that increment is encoded.
    fn consume_capacity(&self, size: u32) {
        let cap = self.recv_size.get();
        // NOTE(review): assumes `size <= recv_size` (callers are `Capacity`
        // handles releasing bytes they reserved); this would underflow otherwise.
        let size = cap - size;
        log::trace!("{:?} capacity decresed from {} to {}", self.id, cap, size);
        self.recv_size.set(size);
        let mut window = self.recv_window.get();
        if let Some(val) = window.update(
            size,
            self.con.config().window_sz.get(),
            self.con.config().window_sz_threshold.get(),
        ) {
            log::trace!(
                "{:?} capacity decresed below threshold {} increase by {} ({})",
                self.id,
                self.con.config().window_sz_threshold.get(),
                val,
                self.con.config().window_sz.get(),
            );
            // The window is only written back when an update is produced.
            self.recv_window.set(window);
            self.con.encode(WindowUpdate::new(self.id, val));
        }
    }
}
impl StreamRef {
    /// Creates the shared state for a new stream on `con`.
    ///
    /// Both halves start `Idle`.
    ///
    /// - The receive window starts at the locally configured `window_sz` once
    ///   `con.settings_processed()` is true, and at
    ///   `frame::DEFAULT_INITIAL_WINDOW_SIZE` before that.
    /// - The send window starts at `con.remote_window_size()`.
    pub(crate) fn new(id: StreamId, remote: bool, con: Connection) -> Self {
        let recv_window = if con.settings_processed() {
            Window::new(con.config().window_sz.get() as i32)
        } else {
            Window::new(frame::DEFAULT_INITIAL_WINDOW_SIZE as i32)
        };
        let send_window = Window::new(con.remote_window_size() as i32);

        StreamRef(Rc::new(StreamState {
            id,
            con,
            recv: Cell::new(HalfState::Idle),
            recv_window: Cell::new(recv_window),
            recv_size: Cell::new(0),
            send: Cell::new(HalfState::Idle),
            send_window: Cell::new(send_window),
            send_cap: LocalWaker::new(),
            send_reset: LocalWaker::new(),
            error: Cell::new(None),
            content_length: Cell::new(ContentLength::Omitted),
            flags: Cell::new(if remote {
                StreamFlags::REMOTE
            } else {
                StreamFlags::empty()
            }),
        }))
    }

    /// Stream identifier.
    #[inline]
    pub fn id(&self) -> StreamId {
        self.0.id
    }

    /// Returns `true` if the stream was created with `remote = true`.
    #[inline]
    pub fn is_remote(&self) -> bool {
        self.0.flags.get().contains(StreamFlags::REMOTE)
    }

    /// Returns `true` once the stream has been reset or has failed.
    #[inline]
    pub fn is_failed(&self) -> bool {
        self.0.flags.get().contains(StreamFlags::FAILED)
    }

    /// Current state of the send half.
    pub(crate) fn send_state(&self) -> HalfState {
        self.0.send.get()
    }

    /// Current state of the receive half.
    pub(crate) fn recv_state(&self) -> HalfState {
        self.0.recv.get()
    }

    /// Resets the stream locally with `reason`.
    ///
    /// Encodes RST_STREAM and closes both halves, unless both halves are
    /// already closed (in which case nothing is sent).
    #[inline]
    pub fn reset(&self, reason: Reason) {
        if !self.0.recv.get().is_closed() || !self.0.send.get().is_closed() {
            self.0.con.encode(Reset::new(self.0.id, reason));
            self.0.reset_stream(Some(reason));
        }
    }

    /// Returns a zero-sized capacity handle bound to this stream.
    ///
    /// Unlike `Capacity::new`, this reserves nothing on the stream.
    #[inline]
    pub fn empty_capacity(&self) -> Capacity {
        Capacity {
            size: Cell::new(0),
            stream: self.0.clone(),
        }
    }

    /// Wraps this reference into an owning [`Stream`].
    #[inline]
    pub(crate) fn into_stream(self) -> Stream {
        Stream(self)
    }

    /// Sends a HEADERS frame and updates the send half accordingly.
    ///
    /// END_HEADERS is always set. If the frame carries END_STREAM the send half
    /// is closed, otherwise it moves to `Payload`.
    pub(crate) fn send_headers(&self, mut hdrs: Headers) {
        hdrs.set_end_headers();
        if hdrs.is_end_stream() {
            self.0.state_send_close(None);
        } else {
            self.0.state_send_payload();
        }
        log::debug!("send headers {:#?} eos: {:?}", hdrs, hdrs.is_end_stream());

        // A response to a HEAD request has no payload (RFC 9110, 9.3.2), even
        // if it carries `content-length`. Mark the receive side so
        // `recv_headers` ignores that header and `recv_data` rejects payload
        // bytes.
        if hdrs.pseudo().method.as_ref() == Some(&Method::HEAD) {
            self.0.content_length.set(ContentLength::Head)
        }
        self.0.con.encode(hdrs);
    }

    /// Closes the stream locally with an optional reason, without encoding RST_STREAM.
    pub(crate) fn set_failed(&self, reason: Option<Reason>) {
        self.0.reset_stream(reason);
    }

    /// Marks the stream as remotely reset with `reason`.
    pub(crate) fn set_go_away(&self, reason: Reason) {
        self.0.remote_reset_stream(reason)
    }

    /// Marks the stream as failed with `err`, closing both halves.
    pub(crate) fn set_failed_stream(&self, err: OperationError) {
        self.0.failed(err);
    }

    /// Processes an inbound HEADERS frame.
    ///
    /// - `Idle`: these are the initial headers. END_STREAM closes the receive
    ///   half. Unless the content length is `Head`, a `content-length` header
    ///   is parsed into `ContentLength::Remaining`.
    /// - `Payload`: these are trailers and must carry END_STREAM.
    /// - `Closed`: the frame is rejected.
    ///
    /// # Errors
    ///
    /// - `InvalidContentLength` if `content-length` is not a valid integer.
    /// - `TrailersWithoutEos` for trailers without END_STREAM.
    /// - `Closed` if the receive half is closed.
    pub(crate) fn recv_headers(&self, hdrs: Headers) -> Result<Option<Message>, StreamError> {
        log::debug!(
            "processing HEADERS for {:?}:\n{:#?}\nrecv_state:{:?}, send_state: {:?}",
            self.0.id,
            hdrs,
            self.0.recv.get(),
            self.0.send.get(),
        );

        match self.0.recv.get() {
            HalfState::Idle => {
                let eof = hdrs.is_end_stream();
                if eof {
                    self.0.state_recv_close(None);
                } else {
                    self.0.state_recv_payload();
                }
                let (pseudo, headers) = hdrs.into_parts();

                if self.0.content_length.get() != ContentLength::Head {
                    if let Some(content_length) = headers.get(CONTENT_LENGTH) {
                        if let Some(v) = parse_u64(content_length.as_bytes()) {
                            self.0.content_length.set(ContentLength::Remaining(v));
                        } else {
                            proto_err!(stream: "could not parse content-length; stream={:?}", self.0.id);
                            return Err(StreamError::InvalidContentLength);
                        }
                    }
                }
                Ok(Some(Message::new(pseudo, headers, eof, self)))
            }
            HalfState::Payload => {
                if !hdrs.is_end_stream() {
                    Err(StreamError::TrailersWithoutEos)
                } else {
                    self.0.state_recv_close(None);
                    Ok(Some(Message::trailers(hdrs.into_fields(), self)))
                }
            }
            HalfState::Closed(_) => Err(StreamError::Closed),
        }
    }

    /// Processes an inbound DATA frame.
    ///
    /// The payload length is reserved as receive capacity up front. On error
    /// paths, and for the final (END_STREAM) frame, the handle is dropped at
    /// the end of this call, which releases it again. For non-final frames it
    /// travels with the returned message. Payload length is checked against
    /// the expected content length.
    ///
    /// # Errors
    ///
    /// - `WrongPayloadLength` if more data arrives than announced, or the
    ///   stream ends early.
    /// - `NonEmptyPayload` for payload on a `Head` stream.
    /// - `Idle` / `Closed` if the receive half is not in `Payload`.
    pub(crate) fn recv_data(&self, data: Data) -> Result<Option<Message>, StreamError> {
        let cap = Capacity::new(data.payload().len() as u32, &self.0);
        log::debug!(
            "processing DATA frame for {:?}, len: {:?}",
            self.0.id,
            data.payload().len()
        );

        match self.0.recv.get() {
            HalfState::Payload => {
                let eof = data.is_end_stream();

                match self.0.content_length.get() {
                    ContentLength::Remaining(rem) => {
                        match rem.checked_sub(data.payload().len() as u64) {
                            Some(val) => {
                                self.0.content_length.set(ContentLength::Remaining(val));
                                if eof && val != 0 {
                                    return Err(StreamError::WrongPayloadLength);
                                }
                            }
                            None => return Err(StreamError::WrongPayloadLength),
                        }
                    }
                    ContentLength::Head => {
                        if !data.payload().is_empty() {
                            return Err(StreamError::NonEmptyPayload);
                        }
                    }
                    _ => (),
                }

                if eof {
                    self.0.state_recv_close(None);
                    Ok(Some(Message::eof_data(data.into_payload(), self)))
                } else {
                    Ok(Some(Message::data(data.into_payload(), cap, self)))
                }
            }
            HalfState::Idle => Err(StreamError::Idle("DATA framed received")),
            HalfState::Closed(_) => Err(StreamError::Closed),
        }
    }

    /// Processes an inbound RST_STREAM frame.
    pub(crate) fn recv_rst_stream(&self, frm: &Reset) {
        self.0.remote_reset_stream(frm.reason())
    }

    /// Applies an inbound WINDOW_UPDATE to the send window.
    ///
    /// Wakes tasks waiting for send capacity once the window is positive.
    ///
    /// # Errors
    ///
    /// - `WindowZeroUpdateValue` for a zero increment.
    /// - `WindowOverflowed` if the window would overflow.
    pub(crate) fn recv_window_update(&self, frm: WindowUpdate) -> Result<(), StreamError> {
        if frm.size_increment() == 0 {
            Err(StreamError::WindowZeroUpdateValue)
        } else {
            let window = self
                .0
                .send_window
                .get()
                .inc(frm.size_increment())
                .map_err(|_| StreamError::WindowOverflowed)?;
            self.0.send_window.set(window);

            if window.window_size() > 0 {
                self.0.send_cap.wake();
            }
            Ok(())
        }
    }

    /// Adjusts the send window by the signed delta `upd`.
    ///
    /// A zero delta is a no-op.
    ///
    /// # Errors
    ///
    /// `WindowOverflowed` if a positive delta would overflow the window.
    pub(crate) fn update_send_window(&self, upd: i32) -> Result<(), StreamError> {
        let orig = self.0.send_window.get();
        let window = match upd.cmp(&0) {
            Ordering::Less => orig.dec(upd.unsigned_abs()),
            Ordering::Greater => orig
                .inc(upd as u32)
                .map_err(|_| StreamError::WindowOverflowed)?,
            Ordering::Equal => return Ok(()),
        };
        log::trace!(
            "Updating send window size from {} to {}",
            orig.window_size,
            window.window_size
        );
        self.0.send_window.set(window);
        Ok(())
    }

    /// Adjusts the receive window by the signed delta `upd`.
    ///
    /// The adjusted window is stored. Returns the increment that
    /// `Window::update` reports should be announced to the peer, if any.
    /// A zero delta is a no-op returning `Ok(None)`.
    ///
    /// # Errors
    ///
    /// `WindowOverflowed` if a positive delta would overflow the window.
    pub(crate) fn update_recv_window(&self, upd: i32) -> Result<Option<WindowSize>, StreamError> {
        let current = self.0.recv_window.get();
        let mut window = match upd.cmp(&0) {
            Ordering::Less => current.dec(upd.unsigned_abs()),
            Ordering::Greater => current
                .inc(upd as u32)
                .map_err(|_| StreamError::WindowOverflowed)?,
            Ordering::Equal => return Ok(None),
        };
        let increment = window.update(
            self.0.recv_size.get(),
            self.0.con.config().window_sz.get(),
            self.0.con.config().window_sz_threshold.get(),
        );
        self.0.recv_window.set(window);
        Ok(increment)
    }

    /// Sends response headers with `status` and `headers`.
    ///
    /// Only valid while the send half is `Idle`. With `eof` the frame carries
    /// END_STREAM and the send half is closed; otherwise it moves to `Payload`.
    ///
    /// # Errors
    ///
    /// - `OperationError::Payload` if headers were already sent.
    /// - `OperationError::Closed` if the send half is closed.
    pub fn send_response(
        &self,
        status: StatusCode,
        headers: HeaderMap,
        eof: bool,
    ) -> Result<(), OperationError> {
        match self.0.send.get() {
            HalfState::Idle => {
                let pseudo = PseudoHeaders::response(status);
                let mut hdrs = Headers::new(self.0.id, pseudo, headers, eof);

                if eof {
                    hdrs.set_end_stream();
                    self.0.state_send_close(None);
                } else {
                    self.0.state_send_payload();
                }
                self.0.con.encode(hdrs);
                Ok(())
            }
            HalfState::Payload => Err(OperationError::Payload),
            HalfState::Closed(r) => Err(OperationError::Closed(r)),
        }
    }

    /// Sends `res` as one or more DATA frames.
    ///
    /// Each frame is bounded by the stream's send window and the peer's
    /// maximum frame size (`con.remote_frame_size()`). The call waits for send
    /// capacity whenever the window is exhausted. With `eof`, the last frame
    /// carries END_STREAM and the send half is closed; an empty `res` with
    /// `eof` sends a single empty END_STREAM frame.
    ///
    /// # Errors
    ///
    /// - A previously recorded stream error.
    /// - A stream or connection error observed while waiting for capacity.
    /// - `OperationError::Idle` if headers were not sent yet.
    /// - `OperationError::Closed` if the send half is closed.
    pub async fn send_payload(&self, mut res: Bytes, eof: bool) -> Result<(), OperationError> {
        match self.0.send.get() {
            HalfState::Payload => {
                self.0.check_error()?;

                log::debug!(
                    "{:?} sending {} bytes, eof: {}, send: {:?}",
                    self.0.id,
                    res.len(),
                    eof,
                    self.0.send.get()
                );

                // Fast path: only END_STREAM needs to be signalled.
                if eof && res.is_empty() {
                    let mut data = Data::new(self.0.id, Bytes::new());
                    data.set_end_stream();
                    self.0.state_send_close(None);
                    self.0.con.encode(data);
                    return Ok(());
                }

                loop {
                    let win = self.available_send_capacity() as usize;
                    if win > 0 {
                        let size =
                            cmp::min(win, cmp::min(res.len(), self.0.con.remote_frame_size()));
                        let mut data = if size >= res.len() {
                            Data::new(self.0.id, mem::replace(&mut res, Bytes::new()))
                        } else {
                            log::trace!(
                                "{:?} sending {} out of {} bytes",
                                self.0.id,
                                size,
                                res.len()
                            );
                            Data::new(self.0.id, res.split_to(size))
                        };
                        if eof && res.is_empty() {
                            data.set_end_stream();
                            self.0.state_send_close(None);
                        }

                        // Charge the frame against the send window before encoding it.
                        self.0
                            .send_window
                            .set(self.0.send_window.get().dec(size as u32));
                        self.0.con.encode(data);
                        if res.is_empty() {
                            return Ok(());
                        }
                    } else {
                        log::trace!(
                            "Not enough sending capacity for {:?} remaining {:?}",
                            self.0.id,
                            res.len()
                        );
                        self.send_capacity().await?;
                    }
                }
            }
            HalfState::Idle => Err(OperationError::Idle),
            HalfState::Closed(reason) => Err(OperationError::Closed(reason)),
        }
    }

    /// Sends trailers and closes the send half.
    ///
    /// Does nothing unless the send half is in `Payload`.
    pub fn send_trailers(&self, map: HeaderMap) {
        if self.0.send.get() == HalfState::Payload {
            let mut hdrs = Headers::trailers(self.0.id, map);
            hdrs.set_end_headers();
            hdrs.set_end_stream();
            self.0.con.encode(hdrs);
            self.0.state_send_close(None);
        }
    }

    /// Current size of the stream's send window.
    pub fn available_send_capacity(&self) -> WindowSize {
        self.0.send_window.get().window_size()
    }

    /// Waits until the send window is positive and returns its size.
    pub async fn send_capacity(&self) -> Result<WindowSize, OperationError> {
        poll_fn(|cx| self.poll_send_capacity(cx)).await
    }

    /// Polls for a positive send window.
    ///
    /// Returns a recorded stream or connection error first. Otherwise returns
    /// the window size if it is positive, or registers `cx` for wake-up and
    /// returns `Pending`.
    pub fn poll_send_capacity(&self, cx: &Context<'_>) -> Poll<Result<WindowSize, OperationError>> {
        self.0.check_error()?;
        self.0.con.check_error()?;

        let win = self.0.send_window.get().window_size();
        if win > 0 {
            Poll::Ready(Ok(win))
        } else {
            self.0.send_cap.register(cx.waker());
            Poll::Pending
        }
    }

    /// Polls until the send half is closed.
    ///
    /// While the send half is still open, a recorded stream or connection
    /// error is returned. Otherwise `cx` is registered for wake-up.
    pub fn poll_send_reset(&self, cx: &Context<'_>) -> Poll<Result<(), OperationError>> {
        if self.0.send.get().is_closed() {
            Poll::Ready(Ok(()))
        } else {
            self.0.check_error()?;
            self.0.con.check_error()?;
            self.0.send_reset.register(cx.waker());
            Poll::Pending
        }
    }
}
impl PartialEq for StreamRef {
    /// Two references are equal when they point at the same shared stream state.
    fn eq(&self, other: &StreamRef) -> bool {
        Rc::ptr_eq(&self.0, &other.0)
    }
}
impl ops::Deref for Stream {
type Target = StreamRef;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Drop for Stream {
fn drop(&mut self) {
self.0.reset(Reason::CANCEL);
}
}
impl fmt::Debug for Stream {
    /// Formats the stream id together with both half states.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state = &self.0 .0;
        f.debug_struct("Stream")
            .field("stream_id", &state.id)
            .field("recv_state", &state.recv.get())
            .field("send_state", &state.send.get())
            .finish()
    }
}
impl fmt::Debug for StreamState {
    /// Formats the id, both half states, both windows, held bytes and flags.
    ///
    /// The connection, wakers and error are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("StreamState")
            .field("id", &self.id)
            .field("recv", &self.recv.get())
            .field("recv_window", &self.recv_window.get())
            .field("recv_size", &self.recv_size.get())
            .field("send", &self.send.get())
            .field("send_window", &self.send_window.get())
            .field("flags", &self.flags.get())
            .finish()
    }
}
/// Parses an unsigned ASCII decimal integer, as carried by `content-length`.
///
/// Accepts one or more ASCII digits (`1*DIGIT`, RFC 9110 section 8.6) and
/// nothing else: signs, whitespace and empty input are rejected. Leading
/// zeros are allowed.
///
/// Returns `None` if `src` is empty, contains a non-digit byte, or encodes a
/// value greater than `u64::MAX`.
pub fn parse_u64(src: &[u8]) -> Option<u64> {
    // An empty header value is not a valid length; it must not read as 0.
    if src.is_empty() {
        return None;
    }
    // Checked arithmetic detects overflow exactly, so every value up to
    // `u64::MAX` (20 digits) is accepted without a length heuristic.
    src.iter().try_fold(0u64, |acc, &d| {
        if !d.is_ascii_digit() {
            return None;
        }
        acc.checked_mul(10)?.checked_add(u64::from(d - b'0'))
    })
}