use std::collections::VecDeque;
use std::task::{Context, Poll, Waker};
use std::{cell::Cell, cell::RefCell, fmt, future::poll_fn, pin::Pin, rc::Rc, rc::Weak};
use ntex_h2::{self as h2};
use crate::util::{Bytes, Stream};
use crate::{http::error::PayloadError, task::LocalWaker};
bitflags::bitflags! {
    /// State bits shared between `Payload` and `PayloadSender`.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Flags: u8 {
        /// Set by `feed_eof`: the sender has signalled that no further chunks follow.
        const EOF = 1 << 0;
        /// Set when the reading half (`Payload`) is dropped.
        const DROPPED = 1 << 1;
    }
}
/// Reading half of a payload channel fed from an `h2` stream.
///
/// Chunks pushed through the paired [`PayloadSender`] are yielded in order.
/// While the payload is not yet complete, each chunk read is consumed from the
/// shared `h2::Capacity`.
#[derive(Debug)]
pub struct Payload {
    inner: Rc<Inner>,
}

impl Payload {
    /// Creates a connected `(sender, receiver)` pair starting from `cap`.
    ///
    /// The receiver holds the only strong reference to the shared state. The
    /// sender keeps a weak one, so its operations become no-ops once the
    /// `Payload` is gone.
    pub fn create(cap: h2::Capacity) -> (PayloadSender, Payload) {
        let inner = Rc::new(Inner::new(cap));
        let sender = PayloadSender {
            inner: Rc::downgrade(&inner),
        };
        (sender, Payload { inner })
    }

    /// Waits for the next chunk.
    ///
    /// Resolves to `None` once the sender signalled EOF and every queued chunk
    /// (and any pending error) has been delivered.
    #[inline]
    pub async fn read(&self) -> Option<Result<Bytes, PayloadError>> {
        poll_fn(move |cx| self.poll_read(cx)).await
    }

    /// Poll-based variant of [`Payload::read`].
    ///
    /// When nothing is buffered, this registers `cx`'s waker for the reader and
    /// wakes the I/O task before returning `Poll::Pending`.
    #[inline]
    pub fn poll_read(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Bytes, PayloadError>>> {
        let shared: &Inner = &self.inner;
        shared.readany(cx)
    }
}
impl Drop for Payload {
    /// Marks the shared state as dropped, then notifies the I/O task.
    ///
    /// The `DROPPED` flag is published *before* waking. However soon the woken
    /// task is polled, `PayloadSender::on_cancel` then observes the flag and
    /// returns `Poll::Ready` instead of re-registering and waiting forever.
    fn drop(&mut self) {
        self.inner.insert_flags(Flags::DROPPED);
        self.inner.io_task.wake();
    }
}
impl Stream for Payload {
    type Item = Result<Bytes, PayloadError>;

    /// Yields the next chunk, with exactly the semantics of `Payload::poll_read`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this: &Payload = &self;
        this.inner.readany(cx)
    }
}
/// Writing half of the channel created by `Payload::create`.
///
/// Only a weak reference to the shared state is held, so every operation
/// silently does nothing once the `Payload` has been dropped.
#[derive(Debug)]
pub struct PayloadSender {
    inner: Weak<Inner>,
}

impl Drop for PayloadSender {
    /// If the sender disappears before `feed_eof` was called, hands the reader a
    /// `PayloadError::Incomplete(None)` so it does not wait indefinitely.
    fn drop(&mut self) {
        let Some(shared) = self.inner.upgrade() else {
            return;
        };
        let finished = shared.flags.get().contains(Flags::EOF);
        if !finished {
            shared.set_error(PayloadError::Incomplete(None));
        }
    }
}
impl PayloadSender {
pub fn set_error(&self, err: PayloadError) {
if let Some(shared) = self.inner.upgrade() {
shared.set_error(err);
}
}
pub fn feed_eof(&self, data: Bytes) {
if let Some(shared) = self.inner.upgrade() {
shared.feed_eof(data);
}
}
pub fn feed_data(&self, data: Bytes, cap: h2::Capacity) {
if let Some(shared) = self.inner.upgrade() {
shared.feed_data(data, cap);
}
}
pub fn set_stream(&self, stream: Option<h2::Stream>) {
if let Some(shared) = self.inner.upgrade() {
shared.stream.set(stream);
}
}
pub(crate) fn on_cancel(&self, w: &Waker) -> Poll<()> {
if let Some(shared) = self.inner.upgrade() {
if shared.flags.get().contains(Flags::DROPPED) {
Poll::Ready(())
} else {
shared.io_task.register(w);
Poll::Pending
}
} else {
Poll::Ready(())
}
}
}
/// State shared between `Payload` (strong owner, via `Rc`) and
/// `PayloadSender` (weak reference).
///
/// Field order is kept as-is: it determines the drop order of the owned
/// `h2::Capacity` / `h2::Stream` values.
struct Inner {
// `EOF` / `DROPPED` state bits.
flags: Cell<Flags>,
// Accumulated capacity. Kept `Some` between operations; it is temporarily
// taken out while being updated or inspected.
cap: Cell<Option<h2::Capacity>>,
// Error handed to the reader once all queued chunks have been read.
err: Cell<Option<PayloadError>>,
// Chunks received but not yet read, in arrival order.
items: RefCell<VecDeque<Bytes>>,
// Waker of the reading task (registered in `readany`).
task: LocalWaker,
// Waker of the I/O task (registered in `PayloadSender::on_cancel`).
io_task: LocalWaker,
// Stream handle stored via `PayloadSender::set_stream`; not read in this file.
stream: Cell<Option<h2::Stream>>,
}
impl Inner {
    /// Creates empty shared state holding the initial capacity `cap`.
    fn new(cap: h2::Capacity) -> Self {
        Inner {
            cap: Cell::new(Some(cap)),
            flags: Cell::new(Flags::empty()),
            err: Cell::new(None),
            stream: Cell::new(None),
            items: RefCell::new(VecDeque::new()),
            task: LocalWaker::new(),
            io_task: LocalWaker::new(),
        }
    }

    /// Sets the bits of `f` in `flags`, keeping the bits already set.
    fn insert_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    /// Stores `err` (replacing any previous error) and wakes the reader.
    fn set_error(&self, err: PayloadError) {
        self.err.set(Some(err));
        self.task.wake();
    }

    /// Marks end of stream, queues `data` if non-empty, and wakes the reader.
    fn feed_eof(&self, data: Bytes) {
        self.insert_flags(Flags::EOF);
        if !data.is_empty() {
            self.items.borrow_mut().push_back(data);
        }
        self.task.wake();
    }

    /// Queues `data`, adds `cap` to the accumulated capacity, and wakes the reader.
    fn feed_data(&self, data: Bytes, cap: h2::Capacity) {
        self.cap.set(Some(self.cap.take().unwrap() + cap));
        self.items.borrow_mut().push_back(data);
        self.task.wake();
    }

    /// Core read step shared by `Payload::poll_read` and `Stream::poll_next`.
    ///
    /// The checks run in this order:
    /// 1. A queued chunk is returned first. Before EOF, its length is consumed
    ///    from the shared capacity.
    /// 2. Otherwise a pending error is returned once, since it is taken out.
    /// 3. Otherwise, after EOF, `None` ends the stream.
    /// 4. Otherwise the reader's waker is registered, the I/O task is woken and
    ///    `Poll::Pending` is returned.
    fn readany(&self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, PayloadError>>> {
        // Pop in a standalone statement so the `RefCell` borrow ends here. If it
        // stayed alive while we call into `h2::Capacity` or register/wake tasks,
        // any re-entrant `feed_data`/`feed_eof` would panic with `BorrowMutError`.
        let next = self.items.borrow_mut().pop_front();

        if let Some(data) = next {
            if !self.flags.get().contains(Flags::EOF) {
                let cap = self.cap.take().unwrap();
                cap.consume(data.len() as u32);
                let size = cap.size();
                self.cap.set(Some(cap));
                // NOTE(review): re-registers the reader when the remaining capacity
                // reaches zero; the intent is not visible in this file — confirm.
                if size == 0 {
                    self.task.register(cx.waker());
                }
            }
            Poll::Ready(Some(Ok(data)))
        } else if let Some(err) = self.err.take() {
            Poll::Ready(Some(Err(err)))
        } else if self.flags.get().contains(Flags::EOF) {
            Poll::Ready(None)
        } else {
            self.task.register(cx.waker());
            self.io_task.wake();
            Poll::Pending
        }
    }
}
impl fmt::Debug for Inner {
    /// Formats the shared state for diagnostics without ever panicking.
    ///
    /// `cap` and `err` sit in `Cell`s of non-`Copy` types, so they are moved out
    /// for the duration of formatting and restored afterwards. A missing
    /// capacity, or an item queue that is currently mutably borrowed, is rendered
    /// as a placeholder instead of panicking.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cap = self.cap.take();
        let err = self.err.take();

        let mut out = f.debug_struct("Inner");
        out.field("flags", &self.flags.get());
        match &cap {
            Some(cap) => out.field("capacity", cap),
            None => out.field("capacity", &format_args!("<none>")),
        };
        out.field("error", &err);
        match self.items.try_borrow() {
            Ok(items) => out.field("items", &*items),
            // Same placeholder `RefCell`'s own `Debug` impl uses.
            Err(_) => out.field("items", &format_args!("<borrowed>")),
        };
        let result = out.finish();

        // Put the moved-out values back so formatting has no lasting effect.
        self.cap.set(cap);
        self.err.set(err);
        result
    }
}