use super::body::BodyInner;
use super::{CompressionBody, CompressionLayer};
use crate::compression::predicate::{DefaultPredicate, Predicate};
use crate::compression::CompressionLevel;
use crate::compression_utils::WrapBody;
use crate::{compression_utils::AcceptEncoding, content_encoding::Encoding};
use http::{header, Request, Response};
use http_body::Body;
use tower_async_service::Service;
/// Middleware that compresses the response bodies produced by the wrapped service.
///
/// The content encoding is negotiated per request from the request headers, restricted
/// to the encodings enabled in `accept`, and is only applied when `predicate` allows it.
#[derive(Copy, Clone)]
pub struct Compression<S, P = DefaultPredicate> {
    /// The wrapped (inner) service.
    pub(crate) inner: S,
    /// Which content encodings this middleware is allowed to negotiate.
    pub(crate) accept: AcceptEncoding,
    /// Decides, per response, whether compression should be attempted.
    pub(crate) predicate: P,
    /// Compression level handed to the body encoder.
    pub(crate) quality: CompressionLevel,
}
impl<S> Compression<S, DefaultPredicate> {
pub fn new(service: S) -> Compression<S, DefaultPredicate> {
Self {
inner: service,
accept: AcceptEncoding::default(),
predicate: DefaultPredicate::default(),
quality: CompressionLevel::default(),
}
}
}
impl<S, P> Compression<S, P> {
define_inner_service_accessors!();
pub fn layer() -> CompressionLayer {
CompressionLayer::new()
}
#[cfg(feature = "compression-gzip")]
pub fn gzip(mut self, enable: bool) -> Self {
self.accept.set_gzip(enable);
self
}
#[cfg(feature = "compression-deflate")]
pub fn deflate(mut self, enable: bool) -> Self {
self.accept.set_deflate(enable);
self
}
#[cfg(feature = "compression-br")]
pub fn br(mut self, enable: bool) -> Self {
self.accept.set_br(enable);
self
}
#[cfg(feature = "compression-zstd")]
pub fn zstd(mut self, enable: bool) -> Self {
self.accept.set_zstd(enable);
self
}
pub fn quality(mut self, quality: CompressionLevel) -> Self {
self.quality = quality;
self
}
pub fn no_gzip(mut self) -> Self {
self.accept.set_gzip(false);
self
}
pub fn no_deflate(mut self) -> Self {
self.accept.set_deflate(false);
self
}
pub fn no_br(mut self) -> Self {
self.accept.set_br(false);
self
}
pub fn no_zstd(mut self) -> Self {
self.accept.set_zstd(false);
self
}
pub fn compress_when<C>(self, predicate: C) -> Compression<S, C>
where
C: Predicate,
{
Compression {
inner: self.inner,
accept: self.accept,
predicate,
quality: self.quality,
}
}
}
impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for Compression<S, P>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Body,
    P: Predicate,
{
    type Response = Response<CompressionBody<ResBody>>;
    type Error = S::Error;

    /// Forwards `req` to the inner service and compresses the response body when allowed.
    ///
    /// The encoding is negotiated from the request headers, restricted to the encodings
    /// enabled in `self.accept`, before the request is handed to the inner service.
    /// The response is passed through uncompressed when any of these hold:
    ///
    /// - it already carries a `Content-Encoding` header;
    /// - it carries a `Content-Range` header (the ranges describe the unencoded
    ///   representation, so compressing a partial response would corrupt it);
    /// - the configured predicate rejects it;
    /// - the negotiated encoding is `Identity`.
    ///
    /// Whenever the response is eligible for compression, `Vary: Accept-Encoding` is
    /// appended, even if identity ends up being chosen. A different `Accept-Encoding`
    /// could have produced a different body, so caches must key on it.
    ///
    /// When a compressed body is produced, two headers are removed: `Content-Length`,
    /// which no longer matches, and `Accept-Ranges`, since ranges over the encoded
    /// stream are not supported. `Content-Encoding` is then set to the negotiated encoding.
    ///
    /// # Errors
    ///
    /// Returns the inner service's error unchanged.
    // Several match arms are compiled only with specific crate features enabled, so
    // depending on the feature set some patterns/paths are unreachable or unused.
    #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)]
    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
        let encoding = Encoding::from_headers(req.headers(), self.accept);
        let res = self.inner.call(req).await?;

        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
            // Never compress partial responses: Content-Range offsets refer to the
            // unencoded bytes and would be wrong for a compressed body.
            && !res.headers().contains_key(header::CONTENT_RANGE)
            && self.predicate.should_compress(&res);

        let (mut parts, body) = res.into_parts();

        if should_compress {
            // Append (not insert) so any existing Vary values from the inner service survive.
            parts
                .headers
                .append(header::VARY, header::ACCEPT_ENCODING.into());
        }

        let body = match (should_compress, encoding) {
            // Compression not wanted, or the client only accepts identity.
            (false, _) | (_, Encoding::Identity) => {
                return Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                ))
            }
            #[cfg(feature = "compression-gzip")]
            (_, Encoding::Gzip) => {
                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-deflate")]
            (_, Encoding::Deflate) => {
                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-br")]
            (_, Encoding::Brotli) => {
                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
            }
            #[cfg(feature = "compression-zstd")]
            (_, Encoding::Zstd) => {
                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
            }
            // Catch-all for `Encoding` variants whose compression feature is not enabled.
            // It is presumably needed because `fs` uses `Encoding` without enabling any
            // compressor (verify). Fall back to an uncompressed body instead of panicking.
            #[cfg(feature = "fs")]
            (true, _) => {
                return Ok(Response::from_parts(
                    parts,
                    CompressionBody::new(BodyInner::identity(body)),
                ));
            }
        };

        parts.headers.remove(header::ACCEPT_RANGES);
        parts.headers.remove(header::CONTENT_LENGTH);
        parts
            .headers
            .insert(header::CONTENT_ENCODING, encoding.into_header_value());

        Ok(Response::from_parts(parts, body))
    }
}