#![expect(
clippy::allow_attributes,
reason = "macro-generated `#[allow]` attributes whose underlying lints fire only for some expansions"
)]
use super::CompressionBody;
use super::CompressionLevel;
use super::body::BodyInner;
use super::predicate::{DefaultPredicate, Predicate, PreferredEncoding};
use crate::headers::encoding::{
AcceptEncoding, Encoding, maybe_preferred_encoding_with_wildcard,
parse_accept_encoding_headers, parse_accept_encoding_wildcard_quality,
};
use crate::layer::remove_header::remove_payload_metadata_headers;
use crate::layer::util::compression::WrapBody;
use crate::{Request, Response, StatusCode, header};
use rama_core::Service;
use rama_core::extensions::ExtensionsRef;
use rama_http_headers::specifier::{Quality, QualityValue};
use rama_http_types::HeaderValue;
use rama_http_types::Method;
use rama_http_types::StreamingBody;
use rama_utils::collections::smallvec::SmallVec;
use rama_utils::macros::define_inner_service_accessors;
use rama_utils::str::submatch_ignore_ascii_case;
/// Service wrapper that negotiates a content coding with the client and
/// compresses the response body produced by the wrapped service.
#[derive(Debug, Clone)]
pub struct Compression<S, P = DefaultPredicate> {
/// The wrapped service whose responses are (possibly) compressed.
pub(crate) inner: S,
/// The encodings this service is configured to consider; it is handed to
/// the `Accept-Encoding` parsing and to the negotiation step.
pub(crate) accept: AcceptEncoding,
/// Consulted per response (`should_compress`) to decide whether that
/// response is a candidate for compression at all.
pub(crate) predicate: P,
/// When `true`, an existing `Content-Encoding` response header does not
/// disable compression; instead the encoding parsed from it is offered to
/// the negotiation as the first choice.
pub(crate) respect_content_encoding_if_possible: bool,
/// Compression level passed to the body encoder.
pub(crate) quality: CompressionLevel,
/// When `true`, a response that may carry a body is answered with
/// `406 Not Acceptable` if no encoding could be negotiated.
pub(crate) enforce_not_acceptable: bool,
}
impl<S> Compression<S, DefaultPredicate> {
    /// Creates a new [`Compression`] wrapping `service`.
    ///
    /// The accepted encodings, the predicate and the compression level all
    /// start from their `Default` values. Respecting an existing
    /// `Content-Encoding` header is disabled, and answering with
    /// `406 Not Acceptable` when negotiation fails is enabled.
    pub fn new(service: S) -> Self {
        let accept = AcceptEncoding::default();
        let predicate = DefaultPredicate::default();
        let quality = CompressionLevel::default();

        Self {
            inner: service,
            accept,
            predicate,
            respect_content_encoding_if_possible: false,
            quality,
            enforce_not_acceptable: true,
        }
    }
}
// Configuration surface of `Compression`: per-encoding toggles, compression
// level, negotiation behaviour and predicate replacement.
// NOTE(review): `generate_set_and_with!` presumably expands each template
// below into `set_*` / `with_*` variants — confirm against the macro definition.
impl<S, P> Compression<S, P> {
// NOTE(review): presumably generates accessors for the wrapped `inner`
// service — confirm against the macro definition.
define_inner_service_accessors!();
// Toggles gzip on the configured `accept` set via `set_gzip`.
rama_utils::macros::generate_set_and_with! {
pub fn gzip(mut self, enable: bool) -> Self {
self.accept.set_gzip(enable);
self
}
}
// Toggles deflate on the configured `accept` set via `set_deflate`.
rama_utils::macros::generate_set_and_with! {
pub fn deflate(mut self, enable: bool) -> Self {
self.accept.set_deflate(enable);
self
}
}
// Toggles brotli on the configured `accept` set via `set_br`.
rama_utils::macros::generate_set_and_with! {
pub fn br(mut self, enable: bool) -> Self {
self.accept.set_br(enable);
self
}
}
// Toggles zstd on the configured `accept` set via `set_zstd`.
rama_utils::macros::generate_set_and_with! {
pub fn zstd(mut self, enable: bool) -> Self {
self.accept.set_zstd(enable);
self
}
}
// Sets the compression level handed to the body encoders.
rama_utils::macros::generate_set_and_with! {
pub fn quality(mut self, quality: CompressionLevel) -> Self {
self.quality = quality;
self
}
}
// Turns on honouring an existing `Content-Encoding` response header.
// This template takes no argument, so it can only enable the behaviour.
rama_utils::macros::generate_set_and_with! {
pub fn respect_content_encoding_if_possible(mut self) -> Self {
self.respect_content_encoding_if_possible = true;
self
}
}
// Controls whether a failed negotiation is answered with `406 Not Acceptable`.
rama_utils::macros::generate_set_and_with! {
pub fn enforce_not_acceptable(mut self, enable: bool) -> Self {
self.enforce_not_acceptable = enable;
self
}
}
/// Replaces the predicate that decides whether a response is compressed.
///
/// Every other setting (wrapped service, accepted encodings, compression
/// level and both behaviour flags) is carried over unchanged; only the
/// predicate type parameter changes.
#[must_use]
pub fn with_compress_predicate<C>(self, predicate: C) -> Compression<S, C>
where
C: Predicate,
{
Compression {
inner: self.inner,
accept: self.accept,
predicate,
respect_content_encoding_if_possible: self.respect_content_encoding_if_possible,
quality: self.quality,
enforce_not_acceptable: self.enforce_not_acceptable,
}
}
}
impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for Compression<S, P>
where
    S: Service<Request<ReqBody>, Output = Response<ResBody>>,
    ResBody: StreamingBody<Data: Send + 'static, Error: Send + 'static> + Send + 'static,
    P: Predicate + Send + Sync + 'static,
    ReqBody: Send + 'static,
{
    type Output = Response<CompressionBody<ResBody>>;
    type Error = S::Error;

    /// Forwards the request to the inner service, then negotiates a content
    /// coding and wraps the response body accordingly.
    ///
    /// The response is passed through with an identity body when compression
    /// does not apply or identity is selected. When nothing can be negotiated,
    /// `enforce_not_acceptable` is set and the response may carry a body, the
    /// status is replaced by `406 Not Acceptable`. Errors of the inner service
    /// are propagated unchanged.
    #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)]
    async fn serve(&self, req: Request<ReqBody>) -> Result<Self::Output, Self::Error> {
        // The request is consumed by the inner service, so everything needed
        // from it is captured up front.
        let req_method = req.method().clone();
        let accepted_encodings: SmallVec<[QualityValue<Encoding>; 4]> =
            parse_accept_encoding_headers(req.headers(), self.accept).collect();
        let wildcard_quality = parse_accept_encoding_wildcard_quality(req.headers());

        let mut response = self.inner.serve(req).await?;

        // HEAD/CONNECT exchanges and 1xx/204/205/304 responses are treated as
        // carrying no body.
        let body_allowed = {
            let bodiless_method = matches!(req_method, Method::HEAD | Method::CONNECT);
            let bodiless_status =
                matches!(response.status().as_u16(), 100..=199 | 204 | 205 | 304);
            !(bodiless_method || bodiless_status)
        };

        // Range responses are never compressed. The predicate is consulted
        // last and only when the cheaper checks pass; it receives `&mut` and
        // may touch the response.
        let eligible = body_allowed
            && !response.headers().contains_key(header::CONTENT_RANGE)
            && self.predicate.should_compress(&mut response);

        // The `Content-Encoding` header is only inspected for eligible
        // responses, and only after the predicate has run.
        let (should_compress, respected_encoding) = if !eligible {
            (false, None)
        } else if self.respect_content_encoding_if_possible {
            let hinted =
                Encoding::maybe_from_content_encoding_header(response.headers(), self.accept);
            (true, hinted)
        } else {
            let already_encoded = response.headers().contains_key(header::CONTENT_ENCODING);
            (!already_encoded, None)
        };

        let negotiated = negotiate_response_encoding(
            &accepted_encodings,
            wildcard_quality,
            self.accept,
            respected_encoding,
            response
                .extensions()
                .get_ref::<PreferredEncoding>()
                .copied(),
        );

        let selected_encoding = if let Some(encoding) = negotiated {
            encoding
        } else if self.enforce_not_acceptable && body_allowed {
            // Nothing acceptable could be negotiated: answer 406, keeping the
            // body untouched and advertising that the outcome depends on
            // `Accept-Encoding`.
            let (mut parts, body) = response.into_parts();
            parts.status = StatusCode::NOT_ACCEPTABLE;
            ensure_vary_accept_encoding(&mut parts.headers);
            return Ok(Response::from_parts(
                parts,
                CompressionBody::new(BodyInner::identity(body)),
            ));
        } else {
            Encoding::Identity
        };

        let (mut parts, body) = response.into_parts();
        if should_compress {
            ensure_vary_accept_encoding(&mut parts.headers);
        }

        // Every case that does not end in an actual encoder (compression not
        // applicable, identity selected, or an encoding without an encoder
        // here) passes the body through unchanged.
        let encoded = match selected_encoding {
            Encoding::Gzip if should_compress => {
                BodyInner::gzip(WrapBody::new(body, self.quality))
            }
            Encoding::Deflate if should_compress => {
                BodyInner::deflate(WrapBody::new(body, self.quality))
            }
            Encoding::Brotli if should_compress => {
                BodyInner::brotli(WrapBody::new(body, self.quality))
            }
            Encoding::Zstd if should_compress => {
                BodyInner::zstd(WrapBody::new(body, self.quality))
            }
            _ => {
                return Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                ));
            }
        };

        // Payload metadata headers are removed before the new coding is
        // announced.
        remove_payload_metadata_headers(&mut parts.headers);
        parts.headers.insert(
            header::CONTENT_ENCODING,
            HeaderValue::from(selected_encoding),
        );

        Ok(Response::from_parts(parts, CompressionBody::new(encoded)))
    }
}
/// Picks the response encoding, or `None` when nothing can be negotiated.
///
/// Candidates are tried in this order:
/// 1. `respected`, if it is listed in `accepted_encodings` with a non-zero
///    quality;
/// 2. `preferred`, under the same condition;
/// 3. whatever `maybe_preferred_encoding_with_wildcard` selects from the
///    accepted encodings, the wildcard quality and the supported set.
fn negotiate_response_encoding(
    accepted_encodings: &[QualityValue<Encoding>],
    wildcard_quality: Option<Quality>,
    supported: AcceptEncoding,
    respected: Option<Encoding>,
    preferred: Option<PreferredEncoding>,
) -> Option<Encoding> {
    // A candidate only counts when it is listed explicitly with a quality
    // above zero; the wildcard is not consulted for this check.
    let explicitly_accepted = |candidate: &Encoding| {
        accepted_encodings
            .iter()
            .any(|entry| entry.value == *candidate && entry.quality.as_u16() > 0)
    };

    respected
        .filter(|encoding| explicitly_accepted(encoding))
        .or_else(|| {
            preferred
                .map(PreferredEncoding::as_encoding)
                .filter(|encoding| explicitly_accepted(encoding))
        })
        .or_else(|| {
            maybe_preferred_encoding_with_wildcard(accepted_encodings, wildcard_quality, supported)
        })
}
/// Appends `Vary: accept-encoding` unless an existing `Vary` header value
/// already matches the header name according to `submatch_ignore_ascii_case`.
fn ensure_vary_accept_encoding(headers: &mut rama_http_types::HeaderMap) {
    // Bound to a local so the byte slice below borrows from a value that
    // lives for the whole function rather than from a temporary.
    let accept_encoding = header::ACCEPT_ENCODING;
    let needle = accept_encoding.as_str().as_bytes();

    let already_listed = headers
        .get_all(header::VARY)
        .iter()
        .any(|existing| submatch_ignore_ascii_case(existing.as_bytes(), needle));
    if already_listed {
        return;
    }

    headers.append(header::VARY, accept_encoding.into());
}