use std::fmt;
use rama_core::error::BoxError;
use rama_core::extensions::Extensions;
use rama_core::{Layer, Service};
use rama_utils::macros::{define_inner_service_accessors, generate_set_and_with};
use super::HtmlRewriteBody;
use crate::headers::ContentType;
use crate::layer::remove_header::{
remove_cache_validation_response_headers, remove_payload_metadata_headers,
};
use crate::layer::util::rewrite_policy::BodyRewritePolicy;
use crate::protocols::html::rewrite::ElementContentHandler;
use crate::protocols::html::selector::Selector;
use crate::{HeaderMap, Request, Response, StreamingBody};
/// Service wrapper that calls an inner service and wraps the response body
/// in an [`HtmlRewriteBody`].
///
/// The body is built in rewriting mode only when `selectors` is non-empty and
/// `policy` approves the response; otherwise it is wrapped as a passthrough.
#[derive(Clone)]
pub struct HtmlRewrite<S, H> {
/// The wrapped service whose response is post-processed.
pub(crate) inner: S,
/// Selectors passed to the rewriting body; an empty list disables rewriting.
pub(crate) selectors: Box<[Selector]>,
/// Handler cloned into every rewriting body that gets created.
pub(crate) handler: H,
/// Decides per response (from its headers and extensions) whether to rewrite.
policy: BodyRewritePolicy,
}
impl<S, H> HtmlRewrite<S, H> {
    /// Builds a rewrite service around `inner`.
    ///
    /// `selectors` is collected into a boxed slice and `handler` is stored as
    /// given. The initial policy is
    /// `BodyRewritePolicy::unencoded_content_type(is_html_content_type)`.
    pub fn new(inner: S, selectors: impl IntoIterator<Item = Selector>, handler: H) -> Self {
        let selectors: Box<[Selector]> = selectors.into_iter().collect();
        let policy = BodyRewritePolicy::unencoded_content_type(is_html_content_type);
        Self {
            inner,
            selectors,
            handler,
            policy,
        }
    }

    // NOTE(review): `generate_set_and_with!` presumably expands the definition
    // below into `set_`/`with_` variants; the body is kept in the exact shape the
    // macro is given today.
    generate_set_and_with! {
        /// Replaces the rewrite policy with a custom predicate over a
        /// [`HeaderMap`] and [`Extensions`].
        pub fn rewrite_policy(
            mut self,
            policy: impl Fn(&HeaderMap, &Extensions) -> bool + Send + Sync + 'static,
        ) -> Self {
            self.policy = BodyRewritePolicy::custom(policy);
            self
        }
    }

    define_inner_service_accessors!();
}
impl<S, H> fmt::Debug for HtmlRewrite<S, H>
where
    S: fmt::Debug,
{
    /// Formats the service as a struct named `HtmlRewrite`.
    ///
    /// The handler is shown by its type name only, so `H` does not need to
    /// implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler_type = std::any::type_name::<H>();

        let mut out = f.debug_struct("HtmlRewrite");
        out.field("inner", &self.inner);
        out.field("selectors", &self.selectors);
        out.field("handler", &handler_type);
        out.field("policy", &self.policy);
        out.finish()
    }
}
impl<S, H, ReqBody, ResBody> Service<Request<ReqBody>> for HtmlRewrite<S, H>
where
    S: Service<Request<ReqBody>, Output = Response<ResBody>>,
    ResBody: StreamingBody<Data: Send + 'static, Error: Into<BoxError> + Send + 'static>
        + Send
        + 'static,
    H: ElementContentHandler + Clone + Send + Sync + 'static,
    ReqBody: Send + 'static,
{
    type Output = Response<HtmlRewriteBody<ResBody, H>>;
    type Error = S::Error;

    /// Serves `req` with the inner service and wraps the response body.
    ///
    /// Errors from the inner service are returned unchanged. The response body
    /// is wrapped as a passthrough when no selectors are configured or the
    /// policy rejects the response. Otherwise the payload-metadata and
    /// cache-validation headers are removed and a rewriting body is installed.
    async fn serve(&self, req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
        let (mut parts, body) = self.inner.serve(req).await?.into_parts();

        // The policy is only consulted when there is at least one selector.
        let skip = self.selectors.is_empty()
            || !self
                .policy
                .should_rewrite(&parts.headers, &parts.extensions);
        if skip {
            return Ok(Response::from_parts(
                parts,
                HtmlRewriteBody::passthrough(body),
            ));
        }

        // Presumably these headers describe the original payload and would no
        // longer match once the body has been rewritten.
        remove_payload_metadata_headers(&mut parts.headers);
        remove_cache_validation_response_headers(&mut parts.headers);

        let rewritten = HtmlRewriteBody::new(body, &self.selectors, self.handler.clone());
        Ok(Response::from_parts(parts, rewritten))
    }
}
/// Returns `true` when the essence of `content_type` (the media type without
/// parameters such as `charset`) is exactly `text/html`.
fn is_html_content_type(content_type: &ContentType) -> bool {
    let mime = content_type.mime();
    matches!(mime.essence_str(), "text/html")
}
/// Layer that produces [`HtmlRewrite`] services sharing one set of selectors,
/// one handler and one rewrite policy.
#[derive(Clone)]
pub struct HtmlRewriteLayer<H> {
/// Selectors handed to every service (and rewriting body) built by this layer.
selectors: Box<[Selector]>,
/// Handler handed to every service (and rewriting body) built by this layer.
handler: H,
/// Policy handed to every service built by this layer.
policy: BodyRewritePolicy,
}
impl<H> HtmlRewriteLayer<H> {
    /// Builds a layer from `selectors` and `handler`.
    ///
    /// `selectors` is collected into a boxed slice. The initial policy is
    /// `BodyRewritePolicy::unencoded_content_type(is_html_content_type)`.
    pub fn new(selectors: impl IntoIterator<Item = Selector>, handler: H) -> Self {
        let selectors: Box<[Selector]> = selectors.into_iter().collect();
        let policy = BodyRewritePolicy::unencoded_content_type(is_html_content_type);
        Self {
            selectors,
            handler,
            policy,
        }
    }

    // NOTE(review): `generate_set_and_with!` presumably expands the definition
    // below into `set_`/`with_` variants; the body is kept in the exact shape the
    // macro is given today.
    generate_set_and_with! {
        /// Replaces the rewrite policy with a custom predicate over a
        /// [`HeaderMap`] and [`Extensions`].
        pub fn rewrite_policy(
            mut self,
            policy: impl Fn(&HeaderMap, &Extensions) -> bool + Send + Sync + 'static,
        ) -> Self {
            self.policy = BodyRewritePolicy::custom(policy);
            self
        }
    }

    /// Wraps `body` in a rewriting [`HtmlRewriteBody`] using this layer's
    /// selectors and a clone of its handler.
    ///
    /// The rewrite policy is not consulted here; the body is always built in
    /// rewriting mode.
    pub fn rewrite_body<B>(&self, body: B) -> HtmlRewriteBody<B, H>
    where
        H: ElementContentHandler + Clone,
    {
        let handler = self.handler.clone();
        HtmlRewriteBody::new(body, &self.selectors, handler)
    }
}
impl<H> fmt::Debug for HtmlRewriteLayer<H> {
    /// Formats the layer as a struct named `HtmlRewriteLayer`.
    ///
    /// The handler is shown by its type name only, so `H` does not need to
    /// implement `Debug` (closures and other opaque handlers can be used
    /// without making the layer un-printable).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HtmlRewriteLayer")
            .field("selectors", &self.selectors)
            .field("handler", &std::any::type_name::<H>())
            .field("policy", &self.policy)
            .finish()
    }
}
impl<S, H> Layer<S> for HtmlRewriteLayer<H>
where
    H: Clone,
{
    type Service = HtmlRewrite<S, H>;

    /// Builds an [`HtmlRewrite`] around `inner` from clones of this layer's
    /// selectors, handler and policy.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            selectors,
            handler,
            policy,
        } = self;
        HtmlRewrite {
            inner,
            selectors: selectors.clone(),
            handler: handler.clone(),
            policy: policy.clone(),
        }
    }

    /// Builds an [`HtmlRewrite`] around `inner`, moving this layer's fields
    /// into the service instead of cloning them.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            selectors,
            handler,
            policy,
        } = self;
        HtmlRewrite {
            inner,
            selectors,
            handler,
            policy,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::headers::HeaderMapExt;
    use crate::{HeaderMap, header};

    /// Only `text/html` (with or without parameters) counts as HTML.
    #[test]
    fn rewrite_content_type_policy() {
        for (raw, expected) in [
            ("text/html", true),
            ("text/html; charset=utf-8", true),
            ("application/xhtml+xml", false),
            ("application/json", false),
        ] {
            let mut headers = HeaderMap::new();
            headers.insert(header::CONTENT_TYPE, raw.parse().expect("valid header"));

            let content_type = headers.typed_get::<ContentType>().expect("content type");
            let actual = is_html_content_type(&content_type);
            assert_eq!(actual, expected, "{content_type}");
        }
    }

    /// The default policy declines HTML responses that carry a content encoding.
    #[test]
    fn rewrite_policy_skips_content_encoded_html() {
        let policy = BodyRewritePolicy::unencoded_content_type(is_html_content_type);

        let mut headers = HeaderMap::new();
        headers.insert(header::CONTENT_TYPE, "text/html".parse().unwrap());
        headers.insert(header::CONTENT_ENCODING, "gzip".parse().unwrap());

        let extensions = Extensions::new();
        assert!(!policy.should_rewrite(&headers, &extensions));
    }
}