use super::CompressionBody;
use super::CompressionLevel;
use super::body::BodyInner;
use super::predicate::{DefaultPredicate, Predicate};
use crate::dep::http_body::Body;
use crate::headers::encoding::{AcceptEncoding, Encoding};
use crate::layer::util::compression::WrapBody;
use crate::{Request, Response, header};
use rama_core::{Context, Service};
use rama_http_types::HeaderValue;
use rama_utils::macros::define_inner_service_accessors;
/// Middleware that wraps a service and compresses the bodies of the responses it produces.
///
/// `P` is the predicate type consulted per response to decide whether compression applies;
/// it defaults to [`DefaultPredicate`].
pub struct Compression<S, P = DefaultPredicate> {
    /// The wrapped service whose responses are (possibly) compressed.
    pub(crate) inner: S,
    /// Decides, for each response, whether it should be compressed.
    pub(crate) predicate: P,
    /// The set of encodings this layer is allowed to negotiate with clients.
    pub(crate) accept: AcceptEncoding,
    /// Compression level handed to the encoders.
    pub(crate) quality: CompressionLevel,
}
impl<S, P> std::fmt::Debug for Compression<S, P>
where
    S: std::fmt::Debug,
    P: std::fmt::Debug,
{
    /// Renders the middleware as `Compression { inner, accept, predicate, quality }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            inner,
            accept,
            predicate,
            quality,
        } = self;

        let mut repr = f.debug_struct("Compression");
        repr.field("inner", inner);
        repr.field("accept", accept);
        repr.field("predicate", predicate);
        repr.field("quality", quality);
        repr.finish()
    }
}
impl<S, P> Clone for Compression<S, P>
where
    S: Clone,
    P: Clone,
{
    /// Clones the inner service and the predicate; the encoding set and
    /// compression level are copied.
    fn clone(&self) -> Self {
        let Self {
            inner,
            accept,
            predicate,
            quality,
        } = self;

        let inner = inner.clone();
        let predicate = predicate.clone();

        Self {
            inner,
            accept: *accept,
            predicate,
            quality: *quality,
        }
    }
}
impl<S> Compression<S, DefaultPredicate> {
pub fn new(service: S) -> Compression<S, DefaultPredicate> {
Self {
inner: service,
accept: AcceptEncoding::default(),
predicate: DefaultPredicate::default(),
quality: CompressionLevel::default(),
}
}
}
impl<S, P> Compression<S, P> {
    define_inner_service_accessors!();

    /// Builder form of [`Self::set_gzip`]: toggles gzip in the accepted encodings.
    pub fn gzip(mut self, enable: bool) -> Self {
        self.set_gzip(enable);
        self
    }

    /// Toggles gzip in the set of encodings this layer may negotiate.
    pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
        self.accept.set_gzip(enable);
        self
    }

    /// Builder form of [`Self::set_deflate`]: toggles deflate in the accepted encodings.
    pub fn deflate(mut self, enable: bool) -> Self {
        self.set_deflate(enable);
        self
    }

    /// Toggles deflate in the set of encodings this layer may negotiate.
    pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
        self.accept.set_deflate(enable);
        self
    }

    /// Builder form of [`Self::set_br`]: toggles brotli in the accepted encodings.
    pub fn br(mut self, enable: bool) -> Self {
        self.set_br(enable);
        self
    }

    /// Toggles brotli in the set of encodings this layer may negotiate.
    pub fn set_br(&mut self, enable: bool) -> &mut Self {
        self.accept.set_br(enable);
        self
    }

    /// Builder form of [`Self::set_zstd`]: toggles zstd in the accepted encodings.
    pub fn zstd(mut self, enable: bool) -> Self {
        self.set_zstd(enable);
        self
    }

    /// Toggles zstd in the set of encodings this layer may negotiate.
    pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
        self.accept.set_zstd(enable);
        self
    }

    /// Builder form of [`Self::set_quality`]: sets the compression level.
    pub fn quality(mut self, quality: CompressionLevel) -> Self {
        self.set_quality(quality);
        self
    }

    /// Sets the compression level handed to the encoders.
    pub fn set_quality(&mut self, quality: CompressionLevel) -> &mut Self {
        self.quality = quality;
        self
    }

    /// Swaps the predicate that decides, per response, whether to compress.
    ///
    /// All other settings (inner service, accepted encodings, level) carry over;
    /// the previous predicate is dropped.
    pub fn compress_when<C>(self, predicate: C) -> Compression<S, C>
    where
        C: Predicate,
    {
        let Self {
            inner,
            accept,
            quality,
            ..
        } = self;

        Compression {
            inner,
            accept,
            predicate,
            quality,
        }
    }
}
/// Serves a request through the inner service and compresses the response body when allowed.
///
/// The encoding is chosen from the request's `Accept-Encoding` headers, limited to the
/// encodings enabled on this layer. The response body is passed through unchanged (wrapped
/// as an identity body) when the response already has a `Content-Encoding`, carries a
/// `Content-Range`, is rejected by the predicate, or when the negotiated encoding is identity.
impl<ReqBody, ResBody, S, P, State> Service<State, Request<ReqBody>> for Compression<S, P>
where
    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Body<Data: Send + 'static, Error: Send + 'static> + Send + 'static,
    P: Predicate + Send + Sync + 'static,
    ReqBody: Send + 'static,
    State: Clone + Send + Sync + 'static,
{
    type Response = Response<CompressionBody<ResBody>>;
    type Error = S::Error;

    async fn serve(
        &self,
        ctx: Context<State>,
        req: Request<ReqBody>,
    ) -> Result<Self::Response, Self::Error> {
        // Negotiate up front: `req` is moved into the inner service below.
        let encoding = Encoding::from_accept_encoding_headers(req.headers(), self.accept);
        let res = self.inner.serve(ctx, req).await?;

        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
            // Compressing partial (range) responses is not supported.
            && !res.headers().contains_key(header::CONTENT_RANGE)
            && self.predicate.should_compress(&res);

        let (mut parts, body) = res.into_parts();

        if should_compress {
            // Whether this response ends up compressed depends on the client's
            // `Accept-Encoding`, so caches must key on it. Only add the token when `Vary`
            // does not already list it (or `*`), to avoid duplicate entries when the inner
            // service or another compression layer has already set it.
            let already_varies = parts
                .headers
                .get_all(header::VARY)
                .iter()
                .filter_map(|value| value.to_str().ok())
                .flat_map(|value| value.split(','))
                .map(str::trim)
                .any(|token| {
                    token == "*" || token.eq_ignore_ascii_case(header::ACCEPT_ENCODING.as_str())
                });
            if !already_varies {
                parts
                    .headers
                    .append(header::VARY, header::ACCEPT_ENCODING.into());
            }
        }

        let body = match (should_compress, encoding) {
            // Not eligible for compression, or the client only accepts identity.
            (false, _) | (_, Encoding::Identity) => {
                return Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                ));
            }
            (_, Encoding::Gzip) => {
                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
            }
            (_, Encoding::Deflate) => {
                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
            }
            (_, Encoding::Brotli) => {
                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
            }
            (_, Encoding::Zstd) => {
                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
            }
            // Fallback for any encoding without a dedicated arm above: pass the body through
            // uncompressed. With the current set of arms this may be unreachable, hence the
            // targeted lint allowance.
            #[allow(unreachable_patterns)]
            (true, _) => {
                return Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                ));
            }
        };

        // The original Content-Length no longer describes the encoded bytes, and range
        // requests are presumably not served over compressed output (see the Content-Range
        // check above), so both headers are dropped before advertising the new encoding.
        parts.headers.remove(header::ACCEPT_RANGES);
        parts.headers.remove(header::CONTENT_LENGTH);
        parts
            .headers
            .insert(header::CONTENT_ENCODING, HeaderValue::from(encoding));

        Ok(Response::from_parts(parts, body))
    }
}