use std::pin::pin;
use std::task::Context;
use std::{cell::RefCell, marker::PhantomData, pin::Pin, ptr, rc::Rc, task::Poll};
use futures::{AsyncRead, Stream, StreamExt};
use std::future::Future;
use super::{
dynamic_exchange::{DynamicExchange, ExchangeEvent},
entity::Chunk,
};
use crate::event::{BodyEvent, Event};
/// Holds the body chunk stream obtained from a [`DynamicExchange`] for events of type `B`.
///
/// The stream sits behind a `RefCell<Option<..>>` so that `stream()` can move it out
/// through a shared reference.
pub struct BodyStreamExchange<B: Event> {
/// Host handle cloned from the awaited event in `new`; `None` when there is no body
/// or no event was obtained. Used by `write_chunk`.
#[cfg(feature = "experimental")]
host: Option<Rc<dyn crate::host::Host>>,
/// Flag supplied to `new` and reported back by `contains_body()`.
contains_body: bool,
// Never read in this file.
// NOTE(review): presumably retained so the exchange outlives the stream — confirm.
#[allow(unused)]
exchange: Rc<RefCell<DynamicExchange>>,
/// Chunk stream built in `new`; becomes `None` once `stream()` has taken it.
stream: RefCell<Option<Pin<Box<dyn Stream<Item = Chunk>>>>>,
/// Ties the type to the event kind `B` without storing a value of it.
_event: PhantomData<B>,
}
impl<B: ExchangeEvent + BodyEvent + 'static> BodyStreamExchange<B> {
    /// Builds the wrapper, waiting for the `B` event only when a body is expected.
    ///
    /// With `contains_body == false` nothing is awaited and no stream (nor host, under
    /// the `experimental` feature) is stored. Otherwise the exchange is mutably borrowed
    /// while `wait_for_event_buffering::<B>(false)` is awaited; if that yields an event,
    /// its data stream is mapped item by item into [`Chunk`]s read with
    /// `read_body(0, chunk_size())`.
    // The `RefCell` borrow deliberately spans the await below, hence the lint allowance.
    #[allow(clippy::await_holding_refcell_ref)]
    pub(super) async fn new(exchange: Rc<RefCell<DynamicExchange>>, contains_body: bool) -> Self {
        if contains_body {
            let mut guard = exchange.as_ref().borrow_mut();
            let event = guard.wait_for_event_buffering::<B>(false).await;

            let body = event.map(|data| {
                let chunks = data
                    .static_event_data_stream()
                    .map(|part| Chunk::new(part.read_body(0, part.chunk_size())));
                Box::pin(chunks) as Pin<Box<dyn Stream<Item = Chunk>>>
            });
            #[cfg(feature = "experimental")]
            let host = event.map(|data| data.host.clone());

            // Release the borrow before `exchange` is moved into the result.
            drop(guard);
            return Self {
                #[cfg(feature = "experimental")]
                host,
                contains_body,
                exchange,
                stream: RefCell::new(body),
                _event: PhantomData,
            };
        }

        Self {
            #[cfg(feature = "experimental")]
            host: None,
            contains_body,
            exchange,
            stream: RefCell::new(None),
            _event: PhantomData,
        }
    }

    /// Returns the `contains_body` flag this value was created with.
    pub(super) fn contains_body(&self) -> bool {
        self.contains_body
    }

    /// Moves the stored chunk stream into a new [`BodyStream`].
    ///
    /// The stream is taken out of `self`, so a later call (or a call when no stream was
    /// ever stored) returns a `BodyStream` with no inner stream.
    pub(super) fn stream(&self) -> BodyStream {
        let taken = self.stream.borrow_mut().take();
        BodyStream {
            inner: taken,
            _lifetime: PhantomData,
        }
    }

    /// Writes `chunk` through the stored host using `B::write_body(host, 0, usize::MAX, bytes)`.
    ///
    /// # Errors
    ///
    /// * `BodyError::BodyNotSent` when no host is stored.
    /// * `BodyError::ExceededBodySize(len)` when the chunk length is greater than or equal
    ///   to `MAX_BODY_SIZE`, unless the `experimental_disable_body_limit_check` feature
    ///   is enabled.
    #[cfg(feature = "experimental")]
    pub(super) fn write_chunk(&self, chunk: Chunk) -> Result<(), super::BodyError> {
        let host = match self.host.as_ref() {
            Some(host) => host,
            None => return Err(super::BodyError::BodyNotSent),
        };

        #[cfg(not(feature = "experimental_disable_body_limit_check"))]
        if chunk.bytes().len() >= crate::hl::body::MAX_BODY_SIZE {
            return Err(super::BodyError::ExceededBodySize(chunk.bytes().len()));
        }

        B::write_body(host.as_ref(), 0, usize::MAX, chunk.bytes());
        Ok(())
    }
}
/// Stream of body [`Chunk`]s, as handed out by `BodyStreamExchange::stream`.
///
/// When `inner` is `None` the stream reports its end immediately.
pub struct BodyStream<'b> {
/// Underlying chunk stream, if one is available.
inner: Option<Pin<Box<dyn Stream<Item = Chunk>>>>,
/// Carries the `'b` lifetime parameter without storing a reference.
_lifetime: PhantomData<&'b ()>,
}
impl BodyStream<'_> {
    /// Waits for the next chunk; resolves to `None` when the stream yields no more items.
    pub async fn next(&mut self) -> Option<Chunk> {
        let upcoming = StreamExt::next(self);
        upcoming.await
    }

    /// Drains the remaining chunks and returns their bytes, concatenated in arrival
    /// order, as a single [`Chunk`].
    pub async fn collect(&mut self) -> Chunk {
        let mut gathered: Vec<u8> = Vec::new();
        loop {
            let Some(chunk) = self.next().await else {
                break;
            };
            let mut piece = chunk.into_bytes();
            gathered.append(&mut piece);
        }
        Chunk::new(gathered)
    }
}
impl<'b> Stream for BodyStream<'b>
where
    Self: 'b,
{
    type Item = Chunk;

    /// Forwards the poll to the inner stream; without an inner stream the end of the
    /// stream (`Poll::Ready(None)`) is reported right away.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.inner.as_mut() {
            Some(inner) => inner.as_mut().poll_next(cx),
            None => Poll::Ready(None),
        }
    }
}
/// Adapter that exposes a [`BodyStream`] through the `AsyncRead` trait.
pub struct BodyStreamAsyncReader<'a> {
/// Source of chunks.
stream: RefCell<BodyStream<'a>>,
/// Chunk currently being read from; `None` until the first chunk has been received.
chunk: RefCell<Option<Chunk>>,
/// Offset into `chunk` of the next byte to hand out.
last: RefCell<usize>,
}
impl<'a> BodyStreamAsyncReader<'a> {
    /// Wraps `stream` in a reader that starts with no buffered chunk and a read
    /// offset of zero.
    pub fn new(stream: BodyStream<'a>) -> Self {
        let stream = RefCell::new(stream);
        let chunk = RefCell::new(None);
        let last = RefCell::new(0);
        Self {
            stream,
            chunk,
            last,
        }
    }
}
impl AsyncRead for BodyStreamAsyncReader<'_> {
    /// Copies bytes from the current chunk into `buf`, fetching the next chunk from the
    /// stream whenever the current one has been fully consumed.
    ///
    /// Returns `Poll::Ready(Ok(n))` with the number of bytes copied (at most `buf.len()`
    /// and at most what is left in the current chunk), `Poll::Ready(Ok(0))` once the
    /// stream has ended or when `buf` is empty, and `Poll::Pending` while the stream has
    /// no chunk ready. This implementation never returns an I/O error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut offset = *self.last.borrow();
        let mut chunk_len = self
            .chunk
            .borrow()
            .as_ref()
            .map_or(0, |c| c.bytes().len());

        // Refill once the current chunk is exhausted. This is a loop rather than a single
        // poll so that zero-length chunks are skipped: answering `Ok(0)` for one of them
        // would be read by callers as end-of-stream and truncate the body.
        while offset == chunk_len {
            match pin!(self.stream.borrow_mut().next()).poll(cx) {
                Poll::Ready(Some(chunk)) => {
                    offset = 0;
                    chunk_len = chunk.bytes().len();
                    *self.chunk.borrow_mut() = Some(chunk);
                    // Keep the stored offset in step with the stored chunk so the state
                    // stays valid even if the next iteration returns `Pending`.
                    *self.last.borrow_mut() = 0;
                }
                Poll::Ready(None) => return Poll::Ready(Ok(0)),
                Poll::Pending => return Poll::Pending,
            }
        }

        let current = self.chunk.borrow();
        let Some(chunk) = current.as_ref() else {
            // Not reachable: leaving the loop above requires a non-empty stored chunk.
            return Poll::Ready(Ok(0));
        };

        // Hand out as much of the remaining chunk as fits into the caller's buffer.
        let read = (chunk_len - offset).min(buf.len());
        let end = offset + read;
        buf[..read].copy_from_slice(&chunk.bytes()[offset..end]);
        *self.last.borrow_mut() = end;

        Poll::Ready(Ok(read))
    }
}