use super::predicate::DefaultPredicate;
use super::{Compression, Predicate};
use crate::headers::encoding::AcceptEncoding;
use crate::layer::util::compression::CompressionLevel;
use rama_core::Layer;
/// [`Layer`] that wraps an inner service in a [`Compression`] middleware.
///
/// All settings held here are handed to every [`Compression`] service the
/// layer produces: plain values are copied and the predicate is cloned (or
/// moved, when the layer is consumed via `into_layer`).
#[derive(Clone, Debug)]
pub struct CompressionLayer<P = DefaultPredicate> {
    /// Encodings this layer may negotiate; toggled through the `gzip`,
    /// `deflate`, `br` and `zstd` setters.
    accept: AcceptEncoding,
    /// Predicate forwarded to the service. NOTE(review): presumably decides
    /// per response whether compression is applied — confirm against
    /// `Predicate`.
    predicate: P,
    /// Flag forwarded to the service; `false` unless explicitly enabled.
    respect_content_encoding_if_possible: bool,
    /// Compression level forwarded to the service.
    quality: CompressionLevel,
    /// Flag forwarded to the service; `true` in the default configuration.
    enforce_not_acceptable: bool,
}
impl<P> Default for CompressionLayer<P>
where
    P: Default,
{
    /// Builds a layer with the default accepted encodings, the predicate's
    /// own default, the default compression level,
    /// `respect_content_encoding_if_possible` switched off and
    /// `enforce_not_acceptable` switched on.
    fn default() -> Self {
        let accept = AcceptEncoding::default();
        let predicate = P::default();
        let quality = CompressionLevel::default();

        Self {
            accept,
            predicate,
            respect_content_encoding_if_possible: false,
            quality,
            enforce_not_acceptable: true,
        }
    }
}
impl<S, P> Layer<S> for CompressionLayer<P>
where
    P: Predicate,
{
    type Service = Compression<S, P>;

    /// Wraps `inner` while leaving this layer usable: the whole configuration
    /// (including the predicate) is cloned and then consumed by
    /// [`Layer::into_layer`].
    fn layer(&self, inner: S) -> Self::Service {
        self.clone().into_layer(inner)
    }

    /// Wraps `inner`, moving this layer's configuration into the service so
    /// the predicate does not need to be cloned.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            accept,
            predicate,
            respect_content_encoding_if_possible,
            quality,
            enforce_not_acceptable,
        } = self;

        Compression {
            inner,
            accept,
            predicate,
            respect_content_encoding_if_possible,
            quality,
            enforce_not_acceptable,
        }
    }
}
impl CompressionLayer {
    /// Creates a [`CompressionLayer`] with the default configuration and the
    /// [`DefaultPredicate`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the compression predicate, keeping every other setting.
    ///
    /// Only defined for the default-predicate layer. Because `self` is
    /// consumed and a layer of a different type (`CompressionLayer<C>`) is
    /// returned, discarding the result simply drops the configured layer,
    /// hence `#[must_use]`.
    #[must_use]
    pub fn with_compress_predicate<C>(self, predicate: C) -> CompressionLayer<C>
    where
        C: Predicate,
    {
        // Struct-update syntax (`..self`) cannot change the type parameter,
        // so every remaining field is carried over explicitly.
        CompressionLayer {
            accept: self.accept,
            predicate,
            respect_content_encoding_if_possible: self.respect_content_encoding_if_possible,
            quality: self.quality,
            enforce_not_acceptable: self.enforce_not_acceptable,
        }
    }
}
impl<P> CompressionLayer<P> {
    // Each `generate_set_and_with!` invocation below expands the given
    // builder-style method into setter/builder variants; the tests in this
    // file call the `with_*` form (e.g. `with_gzip`, `with_quality`).
    rama_utils::macros::generate_set_and_with! {
        /// Enables or disables gzip in the set of encodings this layer may
        /// negotiate.
        pub fn gzip(mut self, enable: bool) -> Self {
            self.accept.set_gzip(enable);
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Enables or disables deflate in the set of encodings this layer may
        /// negotiate.
        pub fn deflate(mut self, enable: bool) -> Self {
            self.accept.set_deflate(enable);
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Enables or disables brotli (`br`) in the set of encodings this
        /// layer may negotiate.
        pub fn br(mut self, enable: bool) -> Self {
            self.accept.set_br(enable);
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Enables or disables zstd in the set of encodings this layer may
        /// negotiate.
        pub fn zstd(mut self, enable: bool) -> Self {
            self.accept.set_zstd(enable);
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Sets the compression level handed to every [`Compression`]
        /// service built by this layer.
        pub fn quality(mut self, quality: CompressionLevel) -> Self {
            self.quality = quality;
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Switches the `respect_content_encoding_if_possible` flag on.
        ///
        /// The flag is off by default and this method takes no argument, so
        /// it cannot be used to switch the flag back off.
        /// NOTE(review): presumably makes the service leave responses that
        /// already carry a `Content-Encoding` untouched when it can — confirm
        /// against `Compression`.
        pub fn respect_content_encoding_if_possible(mut self) -> Self {
            self.respect_content_encoding_if_possible = true;
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Sets the `enforce_not_acceptable` flag (on by default).
        ///
        /// NOTE(review): presumably controls whether a request whose
        /// `Accept-Encoding` matches none of the enabled encodings is
        /// rejected as not acceptable — confirm against `Compression`.
        pub fn enforce_not_acceptable(mut self, enable: bool) -> Self {
            self.enforce_not_acceptable = enable;
            self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Request, Response, body::util::BodyExt, header::ACCEPT_ENCODING};
    use rama_core::Service;
    use rama_core::service::service_fn;
    use rama_core::stream::io::ReaderStream;
    use rama_http_types::Body;
    use std::convert::Infallible;
    use tokio::fs::File;

    /// Test handler: streams the crate's `Cargo.toml` back as the response body.
    async fn handle(_req: Request) -> Result<Response, Infallible> {
        let file = File::open("Cargo.toml").await.expect("file missing");
        let stream = ReaderStream::new(file);
        let body = Body::from_stream(stream);
        Ok(Response::new(body))
    }

    /// Encodings disabled on the layer must never be negotiated, even when the
    /// client advertises them, and the chosen encoding must round-trip.
    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), rama_core::error::BoxError> {
        use std::io::Read;

        fn decode<R: Read>(mut r: R) -> std::io::Result<Vec<u8>> {
            let mut buf = Vec::new();
            r.read_to_end(&mut buf)?;
            Ok(buf)
        }

        let expected = tokio::fs::read("Cargo.toml").await?;

        // zstd is disabled explicitly as well; otherwise this layer would not
        // actually be deflate-only.
        let deflate_only_layer = CompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_br(false)
            .with_gzip(false)
            .with_zstd(false);
        let service = deflate_only_layer.into_layer(service_fn(handle));
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br, zstd")
            .body(Body::empty())?;
        let response = service.serve(request).await?;
        assert_eq!(response.headers()["content-encoding"], "deflate");
        let deflate_body = response.into_body().collect().await?.to_bytes();
        let decoded = decode(flate2::bufread::ZlibDecoder::new(&deflate_body[..]))?;
        assert_eq!(decoded, expected);

        // Same for brotli: every other encoding, zstd included, is switched off.
        let br_only_layer = CompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_gzip(false)
            .with_deflate(false)
            .with_zstd(false);
        let service = br_only_layer.into_layer(service_fn(handle));
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br, zstd")
            .body(Body::empty())?;
        let response = service.serve(request).await?;
        assert_eq!(response.headers()["content-encoding"], "br");
        let br_body = response.into_body().collect().await?.to_bytes();
        let decoded = decode(brotli::Decompressor::new(&br_body[..], 4096))?;
        assert_eq!(decoded, expected);

        Ok(())
    }

    /// zstd output must be decodable by a decoder whose window is capped at
    /// 2^23 bytes (8 MiB), even for a payload larger than that window.
    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), rama_core::error::BoxError> {
        // 18 MiB of zeros: larger than the 8 MiB decoder window set below, so
        // an encoder that picked a bigger window would make decoding fail.
        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; 18 * 1024 * 1024])))
        }

        let zstd_layer = CompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_br(false)
            .with_deflate(false)
            .with_gzip(false);
        let service = zstd_layer.into_layer(service_fn(zeroes));
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;
        let response = service.serve(request).await?;
        assert_eq!(response.headers()["content-encoding"], "zstd");

        let body = response.into_body();
        let bytes = body.collect().await?.to_bytes();
        let mut dec = zstd::Decoder::new(&*bytes)?;
        // Reject frames that need a window larger than 2^23 bytes.
        dec.window_log_max(23)?;
        std::io::copy(&mut dec, &mut std::io::sink())?;
        Ok(())
    }
}