use super::{DecompressedFrom, DecompressionBody, body::BodyInner};
use crate::headers::encoding::{AcceptEncoding, SupportedEncodings};
use crate::layer::remove_header::remove_payload_metadata_headers;
use crate::layer::util::compression::{CompressionLevel, WrapBody};
use crate::{
Request, Response, StreamingBody,
header::{self, ACCEPT_ENCODING},
};
use rama_core::error::{BoxError, ErrorContext as _};
use rama_core::{
Service,
matcher::service::{ServiceMatch, ServiceMatcher},
};
use rama_utils::macros::define_inner_service_accessors;
use std::convert::Infallible;
/// Middleware that decompresses response bodies produced by the inner
/// service and can add an `Accept-Encoding` header to outgoing requests.
///
/// Whether decompression is attempted is decided per request/response by
/// the matcher `M` (see [`DefaultDecompressionMatcher`]).
#[derive(Debug, Clone)]
pub struct Decompression<S, M = DefaultDecompressionMatcher> {
    // The wrapped service whose responses may be decompressed.
    pub(crate) inner: S,
    // Content codings that are enabled: used to build the `Accept-Encoding`
    // value and to gate which `Content-Encoding` values get decoded.
    pub(crate) accept: AcceptEncoding,
    // When true, `Accept-Encoding` is inserted into requests lacking one.
    pub(crate) insert_accept_encoding_header: bool,
    // Forwarded to every decoder via `with_tolerate_decode_errors`.
    pub(crate) tolerate_decode_errors: bool,
    // Two-phase matcher: request phase yields a response-phase matcher.
    pub(crate) matcher: M,
}
impl<S> Decompression<S> {
    /// Wraps `service` using the default configuration:
    /// the default [`AcceptEncoding`] set, `Accept-Encoding` insertion
    /// enabled, decode errors not tolerated, and the
    /// [`DefaultDecompressionMatcher`].
    pub fn new(service: S) -> Self {
        let accept = AcceptEncoding::default();
        Self {
            matcher: DefaultDecompressionMatcher,
            tolerate_decode_errors: false,
            insert_accept_encoding_header: true,
            accept,
            inner: service,
        }
    }
}
impl<S, M> Decompression<S, M> {
    define_inner_service_accessors!();

    /// Swaps in a different matcher, keeping every other setting as-is.
    ///
    /// The matcher decides, for each request and then each response,
    /// whether decompression should be attempted.
    pub fn with_matcher<T>(self, matcher: T) -> Decompression<S, T> {
        let Self {
            inner,
            accept,
            insert_accept_encoding_header,
            tolerate_decode_errors,
            ..
        } = self;
        Decompression {
            inner,
            accept,
            insert_accept_encoding_header,
            tolerate_decode_errors,
            matcher,
        }
    }

    // Whether to insert an `Accept-Encoding` header into requests that do
    // not already have one.
    rama_utils::macros::generate_set_and_with! {
        pub fn insert_accept_encoding_header(mut self, insert: bool) -> Self {
            self.insert_accept_encoding_header = insert;
            self
        }
    }

    // Flag handed to each decoder through `with_tolerate_decode_errors`.
    rama_utils::macros::generate_set_and_with! {
        pub fn tolerate_decode_errors(mut self, tolerate: bool) -> Self {
            self.tolerate_decode_errors = tolerate;
            self
        }
    }

    // Toggles `gzip` in the enabled-encoding set.
    rama_utils::macros::generate_set_and_with! {
        pub fn gzip(mut self, enable: bool) -> Self {
            self.accept.set_gzip(enable);
            self
        }
    }

    // Toggles `deflate` in the enabled-encoding set.
    rama_utils::macros::generate_set_and_with! {
        pub fn deflate(mut self, enable: bool) -> Self {
            self.accept.set_deflate(enable);
            self
        }
    }

    // Toggles `br` (Brotli) in the enabled-encoding set.
    rama_utils::macros::generate_set_and_with! {
        pub fn br(mut self, enable: bool) -> Self {
            self.accept.set_br(enable);
            self
        }
    }

    // Toggles `zstd` in the enabled-encoding set.
    rama_utils::macros::generate_set_and_with! {
        pub fn zstd(mut self, enable: bool) -> Self {
            self.accept.set_zstd(enable);
            self
        }
    }
}
/// Default request-phase matcher for [`Decompression`].
///
/// It matches every request unchanged and yields a
/// [`DefaultResponseDecompressionMatcher`] for the response phase.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultDecompressionMatcher;
/// Default response-phase matcher for [`Decompression`].
///
/// It matches every response unchanged. Decompression is therefore attempted
/// whenever the response's `Content-Encoding` names an enabled coding.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultResponseDecompressionMatcher;
impl<ReqBody> ServiceMatcher<Request<ReqBody>> for DefaultDecompressionMatcher
where
    ReqBody: Send + 'static,
{
    type Service = DefaultResponseDecompressionMatcher;
    type Error = Infallible;
    type ModifiedInput = Request<ReqBody>;

    /// Unconditionally matches.
    ///
    /// The request is returned untouched, together with a
    /// [`DefaultResponseDecompressionMatcher`] to consult once the response
    /// arrives.
    async fn match_service(
        &self,
        input: Request<ReqBody>,
    ) -> Result<ServiceMatch<Self::ModifiedInput, Self::Service>, Self::Error> {
        let response_matcher = DefaultResponseDecompressionMatcher;
        let matched = ServiceMatch {
            service: Some(response_matcher),
            input,
        };
        Ok(matched)
    }
}
impl<ResBody> ServiceMatcher<Response<ResBody>> for DefaultResponseDecompressionMatcher
where
    ResBody: Send + 'static,
{
    type Service = ();
    type Error = Infallible;
    type ModifiedInput = Response<ResBody>;

    /// Unconditionally matches.
    ///
    /// The response is returned untouched and the `()` marker signals that
    /// decompression may proceed.
    async fn match_service(
        &self,
        input: Response<ResBody>,
    ) -> Result<ServiceMatch<Self::ModifiedInput, Self::Service>, Self::Error> {
        let matched = ServiceMatch {
            service: Some(()),
            input,
        };
        Ok(matched)
    }
}
impl<S, M, ReqBody, ResBody> Service<Request<ReqBody>> for Decompression<S, M>
where
S: Service<Request<ReqBody>, Output = Response<ResBody>, Error: Into<BoxError>>,
M: ServiceMatcher<
Request<ReqBody>,
ModifiedInput = Request<ReqBody>,
Service: ServiceMatcher<
Response<ResBody>,
ModifiedInput = Response<ResBody>,
Service = (),
Error: Into<BoxError>,
>,
Error: Into<BoxError>,
>,
ReqBody: Send + 'static,
ResBody: StreamingBody<Data: Send + 'static, Error: Send + 'static> + Send + 'static,
{
type Output = Response<DecompressionBody<ResBody>>;
type Error = BoxError;
async fn serve(&self, req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
let ServiceMatch {
input: mut req,
service: maybe_response_matcher,
} = self
.matcher
.match_service(req)
.await
.context("decompression matcher: request")?;
if self.insert_accept_encoding_header
&& let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING)
&& let Some(accept) = self.accept.maybe_to_header_value()
{
entry.insert(accept);
}
let res = self.inner.serve(req).await.context("inner::serve")?;
let ServiceMatch {
input: res,
service: should_decompress,
} = if let Some(response_matcher) = maybe_response_matcher {
response_matcher
.into_match_service(res)
.await
.context("decompression matcher: response")?
} else {
ServiceMatch {
input: res,
service: None,
}
};
let (mut parts, body) = res.into_parts();
let res = if should_decompress.is_some()
&& let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING)
{
let maybe_marker = match entry.get().as_bytes() {
b"gzip" if self.accept.gzip() => Some(DecompressedFrom::Gzip),
b"deflate" if self.accept.deflate() => Some(DecompressedFrom::Deflate),
b"br" if self.accept.br() => Some(DecompressedFrom::Brotli),
b"zstd" if self.accept.zstd() => Some(DecompressedFrom::Zstd),
_ => None,
};
let Some(marker) = maybe_marker else {
return Ok(Response::from_parts(
parts,
DecompressionBody::new(BodyInner::identity(body)),
));
};
let body = match marker {
DecompressedFrom::Gzip => DecompressionBody::new(BodyInner::gzip(
WrapBody::new(body, CompressionLevel::default())
.with_tolerate_decode_errors(self.tolerate_decode_errors),
)),
DecompressedFrom::Deflate => DecompressionBody::new(BodyInner::deflate(
WrapBody::new(body, CompressionLevel::default())
.with_tolerate_decode_errors(self.tolerate_decode_errors),
)),
DecompressedFrom::Brotli => DecompressionBody::new(BodyInner::brotli(
WrapBody::new(body, CompressionLevel::default())
.with_tolerate_decode_errors(self.tolerate_decode_errors),
)),
DecompressedFrom::Zstd => DecompressionBody::new(BodyInner::zstd(
WrapBody::new(body, CompressionLevel::default())
.with_tolerate_decode_errors(self.tolerate_decode_errors),
)),
};
entry.remove();
remove_payload_metadata_headers(&mut parts.headers);
parts.extensions.insert(marker);
Response::from_parts(parts, body)
} else {
Response::from_parts(parts, DecompressionBody::new(BodyInner::identity(body)))
};
Ok(res)
}
}