use super::StreamCompression;
use crate::headers::encoding::AcceptEncoding;
use crate::layer::compression::Predicate;
use crate::layer::compression::predicate::DefaultStreamPredicate;
use crate::layer::util::compression::CompressionLevel;
use rama_core::Layer;
/// [`Layer`] that wraps an inner service in a [`StreamCompression`] service,
/// which encodes response bodies according to the request's `Accept-Encoding`.
///
/// `P` is the compression [`Predicate`] forwarded to the service; it defaults
/// to [`DefaultStreamPredicate`]. Use `with_compress_predicate` to swap it.
#[derive(Clone, Debug)]
pub struct StreamCompressionLayer<P = DefaultStreamPredicate> {
    // Encodings this layer may produce; toggled by the gzip/deflate/br/zstd setters.
    accept: AcceptEncoding,
    // Predicate forwarded to the service to decide whether a response is compressed.
    predicate: P,
    // Compression level forwarded to the service.
    quality: CompressionLevel,
    // When true, a request whose Accept-Encoding rules out every option
    // (e.g. `*;q=0`) is answered with 406 Not Acceptable; when false the
    // response is passed through uncompressed instead.
    enforce_not_acceptable: bool,
}
impl<P: Default> Default for StreamCompressionLayer<P> {
fn default() -> Self {
Self {
accept: AcceptEncoding::default(),
predicate: P::default(),
quality: CompressionLevel::default(),
enforce_not_acceptable: true,
}
}
}
impl<S, P> Layer<S> for StreamCompressionLayer<P>
where
    P: Predicate,
{
    type Service = StreamCompression<S, P>;

    /// Wraps `inner` with a copy of this layer's configuration.
    ///
    /// The predicate is cloned and the remaining settings are copied, so the
    /// layer stays usable for wrapping further services.
    fn layer(&self, inner: S) -> Self::Service {
        <Self as Layer<S>>::into_layer(self.clone(), inner)
    }

    /// Wraps `inner`, moving this layer's configuration (including the
    /// predicate) into the service without any cloning.
    fn into_layer(self, inner: S) -> Self::Service {
        let Self {
            accept,
            predicate,
            quality,
            enforce_not_acceptable,
        } = self;

        StreamCompression {
            inner,
            accept,
            predicate,
            quality,
            enforce_not_acceptable,
        }
    }
}
impl StreamCompressionLayer {
    /// Creates a [`StreamCompressionLayer`] with the default configuration,
    /// using [`DefaultStreamPredicate`] as its compression predicate.
    ///
    /// Equivalent to `StreamCompressionLayer::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl<P> StreamCompressionLayer<P> {
    /// Replaces the compression predicate with `predicate`.
    ///
    /// The accepted encodings, the quality and the `enforce_not_acceptable`
    /// setting are carried over unchanged. This works for any current predicate
    /// type `P`, not only the default one, so an already customised predicate
    /// can be swapped out as well.
    ///
    /// The layer is consumed and a layer of a new type is returned, so
    /// ignoring the result discards the whole configuration.
    #[must_use]
    pub fn with_compress_predicate<C>(self, predicate: C) -> StreamCompressionLayer<C>
    where
        C: Predicate,
    {
        StreamCompressionLayer {
            accept: self.accept,
            predicate,
            quality: self.quality,
            enforce_not_acceptable: self.enforce_not_acceptable,
        }
    }
}
// Builder-style configuration.
// Tests call the generated `with_*` methods (e.g. `with_gzip`, `with_quality`).
// NOTE(review): judging by the macro name, each `generate_set_and_with!`
// presumably also emits a `set_*` variant taking `&mut self`; confirm in
// rama_utils::macros.
impl<P> StreamCompressionLayer<P> {
    rama_utils::macros::generate_set_and_with! {
        /// Enables or disables gzip as a response encoding this layer may produce.
        pub fn gzip(mut self, enable: bool) -> Self {
            self.accept.set_gzip(enable);
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Enables or disables deflate as a response encoding this layer may produce.
        pub fn deflate(mut self, enable: bool) -> Self {
            self.accept.set_deflate(enable);
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Enables or disables brotli (`br`) as a response encoding this layer may produce.
        pub fn br(mut self, enable: bool) -> Self {
            self.accept.set_br(enable);
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Enables or disables zstd as a response encoding this layer may produce.
        pub fn zstd(mut self, enable: bool) -> Self {
            self.accept.set_zstd(enable);
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Sets the compression level forwarded to the compression service.
        pub fn quality(mut self, quality: CompressionLevel) -> Self {
            self.quality = quality;
            self
        }
    }
    rama_utils::macros::generate_set_and_with! {
        /// Controls what happens when the request's `Accept-Encoding` rules out
        /// every option (e.g. `*;q=0`).
        ///
        /// When `true` (the default), the request is answered with
        /// `406 Not Acceptable`. When `false`, the response is passed through
        /// without a `Content-Encoding`.
        pub fn enforce_not_acceptable(mut self, enable: bool) -> Self {
            self.enforce_not_acceptable = enable;
            self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::compression::predicate::MirrorDecompressed;
    use crate::layer::decompression::DecompressedFrom;
    use crate::{Request, Response, body::util::BodyExt, header::ACCEPT_ENCODING};
    use rama_core::Service;
    use rama_core::extensions::ExtensionsRef;
    use rama_core::service::service_fn;
    use rama_core::stream::io::ReaderStream;
    use rama_http_types::Body;
    use std::convert::Infallible;
    use tokio::fs::File;

    /// Test handler that streams `Cargo.toml` back as the response body.
    ///
    /// The path is relative, so this assumes the working directory is the crate
    /// root (presumably true under `cargo test`).
    async fn handle(_req: Request) -> Result<Response, Infallible> {
        let file = File::open("Cargo.toml").await.expect("file missing");
        let stream = ReaderStream::new(file);
        let body = Body::from_stream(stream);
        Ok(Response::new(body))
    }

    /// Disabling encodings on the layer steers negotiation to the remaining
    /// offered encoding, and the produced body decodes back to the original.
    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), rama_core::error::BoxError> {
        use std::io::Read;

        // Drains a synchronous decoder into a byte vector.
        fn decode<R: Read>(mut r: R) -> std::io::Result<Vec<u8>> {
            let mut buf = Vec::new();
            r.read_to_end(&mut buf)?;
            Ok(buf)
        }

        let expected = tokio::fs::read("Cargo.toml").await?;

        // gzip and br are disabled. The client offers gzip/deflate/br, so deflate
        // is the only overlap (zstd stays enabled but is not offered).
        let deflate_only_layer = StreamCompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_br(false)
            .with_gzip(false);
        let service = deflate_only_layer.into_layer(service_fn(handle));
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;
        let response = service.serve(request).await?;
        assert_eq!(response.headers()["content-encoding"], "deflate");
        let deflate_body = response.into_body().collect().await?.to_bytes();
        // The `deflate` content coding is zlib-framed, hence ZlibDecoder.
        let decoded = decode(flate2::bufread::ZlibDecoder::new(&deflate_body[..]))?;
        assert_eq!(decoded, expected);

        // gzip and deflate are disabled, so br is the only offered encoding left.
        let br_only_layer = StreamCompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_gzip(false)
            .with_deflate(false);
        let service = br_only_layer.into_layer(service_fn(handle));
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;
        let response = service.serve(request).await?;
        assert_eq!(response.headers()["content-encoding"], "br");
        let br_body = response.into_body().collect().await?.to_bytes();
        let decoded = decode(brotli::Decompressor::new(&br_body[..], 4096))?;
        assert_eq!(decoded, expected);

        Ok(())
    }

    /// zstd output at the best quality must stay decodable with a window of at
    /// most 2^23 bytes (8 MiB), even for a body larger than that window.
    /// NOTE(review): "web safe" presumably refers to the window limit web
    /// clients enforce; confirm the intended target.
    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), rama_core::error::BoxError> {
        // 18 MiB of zeroes: larger than the 8 MiB window and highly compressible.
        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; 18_874_368])))
        }

        let zstd_layer = StreamCompressionLayer::new()
            .with_quality(CompressionLevel::Best)
            .with_br(false)
            .with_deflate(false)
            .with_gzip(false);
        let service = zstd_layer.into_layer(service_fn(zeroes));
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;
        let response = service.serve(request).await?;
        assert_eq!(response.headers()["content-encoding"], "zstd");
        let body = response.into_body();
        let bytes = body.collect().await?.to_bytes();
        let mut dec = zstd::Decoder::new(&*bytes)?;
        // Cap the decoder window at 2^23 bytes; decoding fails if the encoder
        // required a larger window.
        dec.window_log_max(23)?;
        std::io::copy(&mut dec, &mut std::io::sink())?;
        Ok(())
    }

    /// With `MirrorDecompressed`, a response marked as originally brotli-encoded
    /// is re-encoded as br, even though gzip is listed first by the client.
    #[tokio::test]
    async fn mirror_decompressed_prefers_original_encoding()
    -> Result<(), rama_core::error::BoxError> {
        let service = StreamCompressionLayer::new()
            .with_compress_predicate(MirrorDecompressed::new())
            .into_layer(service_fn(|_: Request<Body>| async {
                let res = Response::new(Body::from("Hello, World! Hello, World! Hello, World!"));
                // Mark the response as if an upstream layer had decompressed it from brotli.
                res.extensions().insert(DecompressedFrom::Brotli);
                Ok::<_, Infallible>(res)
            }));
        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, br")
            .body(Body::empty())?;
        let response = service.serve(request).await?;
        assert_eq!(response.headers()["content-encoding"], "br");
        Ok(())
    }

    // The following tests check that responses which carry no body, or must not
    // be transformed, are passed through without a Content-Encoding header.

    /// HEAD responses are not compressed.
    #[tokio::test]
    async fn does_not_compress_head_response() {
        use crate::header::CONTENT_ENCODING;
        use rama_http_types::Method;
        let service = StreamCompressionLayer::new().into_layer(service_fn(handle));
        let req = Request::builder()
            .method(Method::HEAD)
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert!(
            !res.headers().contains_key(CONTENT_ENCODING),
            "HEAD response must not carry Content-Encoding"
        );
    }

    /// CONNECT responses are not compressed.
    #[tokio::test]
    async fn does_not_compress_connect_response() {
        use crate::header::CONTENT_ENCODING;
        use rama_http_types::Method;
        let service = StreamCompressionLayer::new().into_layer(service_fn(handle));
        let req = Request::builder()
            .method(Method::CONNECT)
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert!(
            !res.headers().contains_key(CONTENT_ENCODING),
            "CONNECT response must not carry Content-Encoding"
        );
    }

    /// 204 No Content responses are not compressed.
    #[tokio::test]
    async fn does_not_compress_204_response() {
        use crate::header::CONTENT_ENCODING;
        let service =
            StreamCompressionLayer::new().into_layer(service_fn(async |_: Request<Body>| {
                Ok::<_, Infallible>(Response::builder().status(204).body(Body::empty()).unwrap())
            }));
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert!(
            !res.headers().contains_key(CONTENT_ENCODING),
            "204 response must not carry Content-Encoding"
        );
    }

    /// 304 Not Modified responses are not compressed.
    #[tokio::test]
    async fn does_not_compress_304_response() {
        use crate::header::CONTENT_ENCODING;
        let service =
            StreamCompressionLayer::new().into_layer(service_fn(async |_: Request<Body>| {
                Ok::<_, Infallible>(Response::builder().status(304).body(Body::empty()).unwrap())
            }));
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert!(
            !res.headers().contains_key(CONTENT_ENCODING),
            "304 response must not carry Content-Encoding"
        );
    }

    /// Informational (1xx) responses are not compressed.
    #[tokio::test]
    async fn does_not_compress_1xx_response() {
        use crate::header::CONTENT_ENCODING;
        let service =
            StreamCompressionLayer::new().into_layer(service_fn(async |_: Request<Body>| {
                Ok::<_, Infallible>(Response::builder().status(100).body(Body::empty()).unwrap())
            }));
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert!(
            !res.headers().contains_key(CONTENT_ENCODING),
            "1xx response must not carry Content-Encoding"
        );
    }

    /// 205 Reset Content responses are not compressed.
    #[tokio::test]
    async fn does_not_compress_205_response() {
        use crate::header::CONTENT_ENCODING;
        let service =
            StreamCompressionLayer::new().into_layer(service_fn(async |_: Request<Body>| {
                Ok::<_, Infallible>(Response::builder().status(205).body(Body::empty()).unwrap())
            }));
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert!(
            !res.headers().contains_key(CONTENT_ENCODING),
            "205 response must not carry Content-Encoding"
        );
    }

    /// 206 Partial Content responses (with Content-Range) are not compressed.
    /// Compressing them would break the byte offsets the range refers to.
    #[tokio::test]
    async fn does_not_compress_range_response() {
        use crate::header::{CONTENT_ENCODING, CONTENT_RANGE};
        let service =
            StreamCompressionLayer::new().into_layer(service_fn(async |_: Request<Body>| {
                Ok::<_, Infallible>(
                    Response::builder()
                        .status(206)
                        .header(CONTENT_RANGE, "bytes 0-4/10")
                        .body(Body::from("hello"))
                        .unwrap(),
                )
            }));
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert!(
            !res.headers().contains_key(CONTENT_ENCODING),
            "range response must not carry Content-Encoding"
        );
    }

    /// `*;q=0` rules out every encoding, including identity. With the default
    /// `enforce_not_acceptable == true`, the layer answers 406 Not Acceptable.
    #[tokio::test]
    async fn wildcard_q_zero_returns_406() {
        use crate::StatusCode;
        let service = StreamCompressionLayer::new().into_layer(service_fn(handle));
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "*;q=0")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::NOT_ACCEPTABLE);
    }

    /// With `enforce_not_acceptable` turned off, the same `*;q=0` request gets
    /// the inner 200 response, passed through uncompressed.
    #[tokio::test]
    async fn enforce_not_acceptable_opt_out_falls_back_to_identity() {
        use crate::StatusCode;
        use crate::header::CONTENT_ENCODING;
        let service = StreamCompressionLayer::new()
            .with_enforce_not_acceptable(false)
            .into_layer(service_fn(handle));
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "*;q=0")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res.headers().contains_key(CONTENT_ENCODING));
    }
}