use std::pin::Pin;
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
use rama_core::bytes::{Buf, Bytes};
use rama_core::error::BoxError;
use rama_core::futures::ready;
use crate::HeaderMap;
use crate::body::{Frame, SizeHint, StreamingBody};
use crate::protocols::html::rewrite::{ElementContentHandler, HtmlRewriter};
use crate::protocols::html::selector::Selector;
/// Boxed callback that receives the element handler once rewriting has
/// finished successfully.
type OnEnd<H> = Box<dyn FnOnce(H) + Send + Sync>;
pin_project! {
/// A [`StreamingBody`] wrapper that streams the data frames of `inner`
/// through an [`HtmlRewriter`] (rewriting mode), or forwards them unchanged
/// with their data converted to [`Bytes`] (passthrough mode).
pub struct HtmlRewriteBody<B, H> {
// The wrapped body; pinned because it is polled through `Pin<&mut Self>`.
#[pin]
inner: B,
// `Some` in rewriting mode; `None` in passthrough mode and after the
// rewriter has been finalized and handed to `fire_on_end`.
rewriter: Option<HtmlRewriter<H>>,
// Optional callback invoked with the handler after rewriting completes.
on_end: Option<OnEnd<H>>,
// Trailers from `inner`, held back so the final rewritten data frame can be
// emitted before them.
pending_trailers: Option<HeaderMap>,
// Once `true`, every subsequent poll returns `Poll::Ready(None)`.
done: bool,
}
}
impl<B, H> HtmlRewriteBody<B, H>
where
    H: ElementContentHandler,
{
    /// Creates a body in rewriting mode.
    ///
    /// Every data frame produced by `inner` is fed into an [`HtmlRewriter`]
    /// built from `selectors` and `handler`. Only the rewriter's output is
    /// emitted as data.
    pub fn new(inner: B, selectors: &[Selector], handler: H) -> Self {
        let html_rewriter = HtmlRewriter::new(selectors, handler);
        Self {
            done: false,
            pending_trailers: None,
            on_end: None,
            rewriter: Some(html_rewriter),
            inner,
        }
    }
}
impl<B, H> HtmlRewriteBody<B, H> {
pub fn passthrough(inner: B) -> Self {
Self {
inner,
rewriter: None,
on_end: None,
pending_trailers: None,
done: false,
}
}
#[must_use]
pub fn on_end<F>(mut self, on_end: F) -> Self
where
F: FnOnce(H) + Send + Sync + 'static,
{
self.on_end = Some(Box::new(on_end));
self
}
}
/// Empties both slots and, when both held a value, passes the rewriter's
/// handler to the callback.
///
/// Both `rewriter` and `on_end` are always left as `None`. Whichever value
/// had no partner is simply dropped.
fn fire_on_end<H: ElementContentHandler>(
    rewriter: &mut Option<HtmlRewriter<H>>,
    on_end: &mut Option<OnEnd<H>>,
) {
    let finished_rewriter = rewriter.take();
    let callback = on_end.take();
    if let Some((finished, notify)) = finished_rewriter.zip(callback) {
        notify(finished.into_handler());
    }
}
impl<B, H> StreamingBody for HtmlRewriteBody<B, H>
where
    B: StreamingBody<Error: Into<BoxError>>,
    H: ElementContentHandler,
{
    type Data = Bytes;
    type Error = BoxError;

    /// Polls the next frame of the (possibly rewritten) body.
    ///
    /// * Passthrough mode: frames from `inner` are forwarded, with data
    ///   converted to [`Bytes`].
    /// * Rewriting mode: each data chunk is written into the rewriter, and only
    ///   non-empty rewriter output is emitted. When `inner` ends or yields
    ///   trailers, the rewriter is finalized, remaining output is flushed
    ///   (before the trailers, if any), and `on_end` (if set) receives the
    ///   handler. Frames that are neither data nor trailers are skipped.
    ///
    /// # Errors
    ///
    /// Errors from `inner` are converted into [`BoxError`]. Errors from the
    /// rewriter are returned as-is and are terminal: the body is marked done,
    /// so it never writes to or finalizes a failed rewriter again.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let mut this = self.project();

        // Trailers were held back while the final rewritten data was emitted.
        if let Some(trailers) = this.pending_trailers.take() {
            *this.done = true;
            return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
        }
        if *this.done {
            return Poll::Ready(None);
        }

        let Some(rewriter) = this.rewriter.as_mut() else {
            return match ready!(this.inner.as_mut().poll_frame(cx)) {
                Some(Ok(frame)) => Poll::Ready(Some(Ok(normalize_frame(frame)))),
                Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
                None => {
                    *this.done = true;
                    Poll::Ready(None)
                }
            };
        };

        loop {
            match ready!(this.inner.as_mut().poll_frame(cx)) {
                Some(Ok(frame)) => match frame.into_data() {
                    Ok(mut data) => {
                        while data.has_remaining() {
                            let chunk = data.chunk();
                            let len = chunk.len();
                            if let Err(err) = rewriter.write(chunk) {
                                // A failed rewriter must not be written to or ended again.
                                *this.done = true;
                                return Poll::Ready(Some(Err(err)));
                            }
                            data.advance(len);
                        }
                        let out = rewriter.take_output();
                        // Empty output: keep polling instead of emitting an empty frame.
                        if !out.is_empty() {
                            return Poll::Ready(Some(Ok(Frame::data(Bytes::from(out)))));
                        }
                    }
                    Err(frame) => {
                        if let Ok(trailers) = frame.into_trailers() {
                            if let Err(err) = rewriter.end() {
                                // Without this, the next poll would reach the
                                // end-of-stream path and call `end()` a second time.
                                *this.done = true;
                                return Poll::Ready(Some(Err(err)));
                            }
                            let out = rewriter.take_output();
                            fire_on_end(this.rewriter, this.on_end);
                            if out.is_empty() {
                                *this.done = true;
                                return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
                            }
                            // Emit the final data now and the trailers on the next poll.
                            *this.pending_trailers = Some(trailers);
                            return Poll::Ready(Some(Ok(Frame::data(Bytes::from(out)))));
                        }
                    }
                },
                Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
                None => {
                    *this.done = true;
                    if let Err(err) = rewriter.end() {
                        return Poll::Ready(Some(Err(err)));
                    }
                    let out = rewriter.take_output();
                    fire_on_end(this.rewriter, this.on_end);
                    return if out.is_empty() {
                        Poll::Ready(None)
                    } else {
                        Poll::Ready(Some(Ok(Frame::data(Bytes::from(out)))))
                    };
                }
            }
        }
    }

    /// Returns an unknown size while a rewriter is active, because rewritten
    /// output can be larger or smaller than the input. Otherwise defers to
    /// `inner`.
    fn size_hint(&self) -> SizeHint {
        if self.rewriter.is_some() {
            SizeHint::default()
        } else {
            self.inner.size_hint()
        }
    }
}
/// Converts a frame with an arbitrary [`Buf`] payload into a `Frame<Bytes>`.
///
/// Data payloads are copied into a contiguous [`Bytes`]. Every other frame
/// kind (e.g. trailers) is carried over unchanged via `map_data`. An
/// unrecognized frame is therefore never replaced by a fabricated empty data
/// frame.
fn normalize_frame<D: Buf>(frame: Frame<D>) -> Frame<Bytes> {
    frame.map_data(|mut data| {
        let len = data.remaining();
        data.copy_to_bytes(len)
    })
}