use std::task::{Context, Poll};
use http::{Request, Response};
use http_body::Body;
use tower::{Layer, Service};
use tower_http::decompression::{self, DecompressionBody, ResponseFuture};
use crate::config::RequestConfig;
#[derive(Clone)]
pub(crate) struct AcceptEncoding {
#[cfg(feature = "gzip")]
pub(crate) gzip: bool,
#[cfg(feature = "brotli")]
pub(crate) brotli: bool,
#[cfg(feature = "zstd")]
pub(crate) zstd: bool,
#[cfg(feature = "deflate")]
pub(crate) deflate: bool,
}
#[derive(Clone)]
pub struct DecompressionLayer {
accept: AcceptEncoding,
}
#[derive(Clone)]
pub struct Decompression<S>(Option<decompression::Decompression<S>>);
impl Default for AcceptEncoding {
fn default() -> AcceptEncoding {
AcceptEncoding {
#[cfg(feature = "gzip")]
gzip: true,
#[cfg(feature = "brotli")]
brotli: true,
#[cfg(feature = "zstd")]
zstd: true,
#[cfg(feature = "deflate")]
deflate: true,
}
}
}
impl_request_config_value!(AcceptEncoding);
impl DecompressionLayer {
#[inline(always)]
pub const fn new(accept: AcceptEncoding) -> Self {
Self { accept }
}
}
impl<S> Layer<S> for DecompressionLayer {
type Service = Decompression<S>;
#[inline(always)]
fn layer(&self, service: S) -> Self::Service {
Decompression(Some(Decompression::<S>::accept_in_place(
decompression::Decompression::new(service),
&self.accept,
)))
}
}
impl<S> Decompression<S> {
const BUG_MSG: &str = "[BUG] Decompression service not initialized; bug in setup";
fn accept_in_place(
mut decoder: decompression::Decompression<S>,
accept: &AcceptEncoding,
) -> decompression::Decompression<S> {
#[cfg(feature = "gzip")]
{
decoder = decoder.gzip(accept.gzip);
}
#[cfg(feature = "deflate")]
{
decoder = decoder.deflate(accept.deflate);
}
#[cfg(feature = "brotli")]
{
decoder = decoder.br(accept.brotli);
}
#[cfg(feature = "zstd")]
{
decoder = decoder.zstd(accept.zstd);
}
decoder
}
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Decompression<S>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ReqBody: Body,
ResBody: Body,
{
type Response = Response<DecompressionBody<ResBody>>;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;
#[inline(always)]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.as_mut().expect(Self::BUG_MSG).poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
if let Some(accept) = RequestConfig::<AcceptEncoding>::get(req.extensions()) {
if let Some(decoder) = self.0.take() {
self.0
.replace(Decompression::accept_in_place(decoder, accept));
}
debug_assert!(self.0.is_some());
}
self.0.as_mut().expect(Self::BUG_MSG).call(req)
}
}