pub mod predicate;
pub mod stream;
pub(crate) mod body;
mod layer;
mod pin_project_cfg;
mod service;
#[doc(inline)]
pub use self::{
body::CompressionBody,
layer::CompressionLayer,
predicate::{DefaultPredicate, MirrorDecompressed, Predicate, PreferredEncoding},
service::Compression,
};
#[doc(inline)]
pub use crate::layer::util::compression::CompressionLevel;
#[cfg(test)]
mod tests {
use super::*;
use crate::layer::compression::predicate::{MirrorDecompressed, PreferredEncoding, SizeAbove};
use crate::layer::decompression::DecompressedFrom;
use crate::header::{
ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE,
};
use crate::{HeaderMap, HeaderValue, Request, Response, StreamingBody, body::util::BodyExt};
use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
use flate2::read::GzDecoder;
use rama_core::Service;
use rama_core::bytes::Bytes;
use rama_core::error::BoxError;
use rama_core::extensions::ExtensionsRef;
use rama_core::service::service_fn;
use rama_core::stream::io::StreamReader;
use rama_http_types::Body;
use std::convert::Infallible;
use std::io::Read;
use std::sync::{Arc, RwLock};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Test-only predicate that opts every response into compression, so tests can
/// exercise the encoding path without the default predicate's filtering.
#[derive(Clone)]
struct Always;

impl Predicate for Always {
    fn should_compress<B: StreamingBody>(
        &self,
        _response: &mut rama_http_types::Response<B>,
    ) -> bool {
        true
    }
}
/// Requesting `gzip` yields a body that gunzips back to the handler's payload.
#[tokio::test]
async fn gzip_works() {
    let req = Request::builder()
        .header("accept-encoding", "gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);

    let compressed = svc
        .serve(req)
        .await
        .unwrap()
        .into_body()
        .collect()
        .await
        .unwrap()
        .to_bytes();

    let mut text = String::new();
    GzDecoder::new(&compressed[..])
        .read_to_string(&mut text)
        .unwrap();
    assert_eq!(text, "Hello, World!");
}
/// The legacy `x-gzip` token is honoured; the response advertises the canonical
/// `gzip` coding as its single `content-encoding` value.
#[tokio::test]
async fn x_gzip_works() {
    let req = Request::builder()
        .header("accept-encoding", "x-gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);
    let res = svc.serve(req).await.unwrap();

    let advertised: Vec<&HeaderValue> = res
        .headers()
        .get_all("content-encoding")
        .iter()
        .collect();
    assert_eq!(advertised, vec![HeaderValue::from_static("gzip")]);

    let compressed = res.into_body().collect().await.unwrap().to_bytes();
    let mut text = String::new();
    GzDecoder::new(&compressed[..])
        .read_to_string(&mut text)
        .unwrap();
    assert_eq!(text, "Hello, World!");
}
/// Requesting `zstd` yields a body that zstd-decodes back to the handler's payload.
#[tokio::test]
async fn zstd_works() {
    let req = Request::builder()
        .header("accept-encoding", "zstd")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);

    let res = svc.serve(req).await.unwrap();
    let compressed = res.into_body().collect().await.unwrap().to_bytes();

    let raw = zstd::stream::decode_all(std::io::Cursor::new(compressed)).unwrap();
    assert_eq!(String::from_utf8(raw).unwrap(), "Hello, World!");
}
/// `MirrorDecompressed` re-applies gzip to a response tagged with
/// `DecompressedFrom::Gzip`.
#[tokio::test]
async fn predicate_only_compresses_previously_decompressed_responses() {
    let upstream = service_fn(async |_| {
        let res = Response::new(Body::from("Hello, World!"));
        // Tag the response as having previously been decompressed from gzip.
        res.extensions().insert(DecompressedFrom::Gzip);
        Ok::<_, Infallible>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(MirrorDecompressed::new());

    let req = Request::builder()
        .header("accept-encoding", "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();
    assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

    let compressed = res.into_body().collect().await.unwrap().to_bytes();
    let mut text = String::new();
    GzDecoder::new(&compressed[..])
        .read_to_string(&mut text)
        .unwrap();
    assert_eq!(text, "Hello, World!");
}
/// Without a `DecompressedFrom` marker, `MirrorDecompressed` leaves the response
/// body and headers unencoded.
#[tokio::test]
async fn predicate_skips_responses_that_were_not_decompressed() {
    let upstream =
        service_fn(async |_| Ok::<_, Infallible>(Response::new(Body::from("Hello, World!"))));
    let svc = Compression::new(upstream).with_compress_predicate(MirrorDecompressed::new());

    let req = Request::builder()
        .header("accept-encoding", "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();

    let was_encoded = res.headers().contains_key(CONTENT_ENCODING);
    assert!(!was_encoded);

    let body = res.into_body().collect().await.unwrap().to_bytes();
    assert_eq!(&body[..], b"Hello, World!");
}
/// Besides opting in, `MirrorDecompressed` records the original coding as a
/// `PreferredEncoding` extension on the response.
#[tokio::test]
async fn mirror_decompressed_sets_preferred_encoding() {
    let predicate = MirrorDecompressed::new();
    let mut res = Response::new(Body::from("Hello, World!"));
    res.extensions().insert(DecompressedFrom::Brotli);

    let opted_in = predicate.should_compress(&mut res);
    assert!(opted_in);

    let preferred = res.extensions().get_ref::<PreferredEncoding>();
    assert_eq!(preferred, Some(&PreferredEncoding::Brotli));
}
/// With `with_respect_content_encoding_if_possible`, an upstream
/// `content-encoding: gzip` wins over a `PreferredEncoding::Brotli` extension.
#[tokio::test]
async fn respect_content_encoding_overrides_predicate_preference() {
    let upstream = service_fn(async |_| {
        let mut res = Response::new(Body::from("Hello, World! Hello, World! Hello, World!"));
        res.extensions().insert(PreferredEncoding::Brotli);
        res.headers_mut()
            .insert(CONTENT_ENCODING, HeaderValue::from_static("gzip"));
        Ok::<_, Infallible>(res)
    });
    let svc = Compression::new(upstream)
        .with_respect_content_encoding_if_possible()
        .with_compress_predicate(Always);

    let req = Request::builder()
        .header("accept-encoding", "gzip, br")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();

    assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
}
/// A response that is already `br`-encoded upstream is passed through untouched,
/// even though the client only advertised `gzip`.
#[tokio::test]
async fn no_recompress() {
    const DATA: &str = "Hello, World! I'm already compressed with br!";
    let svc = service_fn(async |_| {
        let buf = {
            let mut buf = Vec::new();
            let mut enc = BrotliEncoder::new(&mut buf);
            enc.write_all(DATA.as_bytes()).await?;
            // `shutdown` (unlike `flush`) finishes the brotli stream, so the upstream
            // body is a complete encoding rather than a flushed-but-open one.
            enc.shutdown().await?;
            buf
        };
        let resp = Response::builder()
            .header("content-encoding", "br")
            .body(Body::from(buf))
            .unwrap();
        Ok::<_, std::io::Error>(resp)
    });
    let svc = Compression::new(svc);
    let req = Request::builder()
        .header("accept-encoding", "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();
    assert_eq!(
        res.headers()
            .get("content-encoding")
            .and_then(|h| h.to_str().ok())
            .unwrap_or_default(),
        "br",
    );

    let data = res.into_body().collect().await.unwrap().to_bytes();
    let data = {
        let mut output_buf = Vec::new();
        let mut decoder = BrotliDecoder::new(&mut output_buf);
        decoder
            .write_all(&data)
            .await
            .expect("couldn't brotli-decode");
        // `shutdown` errors if the input ended before the brotli stream was complete,
        // so a truncated or altered pass-through body fails here instead of passing.
        decoder
            .shutdown()
            .await
            .expect("couldn't finish brotli stream");
        output_buf
    };
    assert_eq!(data, DATA.as_bytes());
}
/// Shared upstream handler: ignores the request and replies with a fresh
/// `Response` whose body is `Hello, World!`.
async fn handle(_req: Request) -> Result<Response, Infallible> {
    Ok(Response::new(Body::from("Hello, World!")))
}
#[tokio::test]
async fn will_not_compress_if_filtered_out() {
use predicate::Predicate;
const DATA: &str = "Hello world uncompressed";
let svc_fn = service_fn(async |_| {
let resp = Response::builder()
.body(Body::from(DATA.as_bytes()))
.unwrap();
Ok::<_, std::io::Error>(resp)
});
#[derive(Default, Clone)]
struct EveryOtherResponse(Arc<RwLock<u64>>);
impl Predicate for EveryOtherResponse {
fn should_compress<B>(&self, _: &mut rama_http_types::Response<B>) -> bool
where
B: StreamingBody,
{
let mut guard = self.0.write().unwrap();
let should_compress = !(*guard).is_multiple_of(2);
*guard += 1;
should_compress
}
}
let svc = Compression::new(svc_fn).with_compress_predicate(EveryOtherResponse::default());
let req = Request::builder()
.header("accept-encoding", "br")
.body(Body::empty())
.unwrap();
let res = svc.serve(req).await.unwrap();
let body = res.into_body();
let data = body.collect().await.unwrap().to_bytes();
let still_uncompressed = String::from_utf8(data.to_vec()).unwrap();
assert_eq!(DATA, &still_uncompressed);
let req = Request::builder()
.header("accept-encoding", "br")
.body(Body::empty())
.unwrap();
let res = svc.serve(req).await.unwrap();
let body = res.into_body();
let data = body.collect().await.unwrap().to_bytes();
String::from_utf8(data.to_vec()).unwrap_err();
}
/// With no custom predicate, an `image/png` response is left unencoded even though
/// its body is twice `SizeAbove::DEFAULT_MIN_SIZE`.
#[tokio::test]
async fn doesnt_compress_images() {
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let payload = "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize);
        let mut res = Response::new(Body::from(payload));
        res.headers_mut()
            .insert(CONTENT_TYPE, "image/png".parse().unwrap());
        Ok(res)
    }

    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle));
    let res = svc.serve(req).await.unwrap();

    assert!(res.headers().get(CONTENT_ENCODING).is_none());
}
/// With no custom predicate, `image/svg+xml` is still gzip-encoded when the body is
/// twice `SizeAbove::DEFAULT_MIN_SIZE`.
#[tokio::test]
async fn does_compress_svg() {
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let payload = "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize);
        let mut res = Response::new(Body::from(payload));
        res.headers_mut()
            .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap());
        Ok(res)
    }

    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle));
    let res = svc.serve(req).await.unwrap();

    assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
}
/// With no custom predicate, an `application/grpc-web+proto` response is
/// gzip-encoded when the body is twice `SizeAbove::DEFAULT_MIN_SIZE`.
#[tokio::test]
async fn does_compress_grpc_web() {
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let payload = "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize);
        let mut res = Response::new(Body::from(payload));
        res.headers_mut()
            .insert(CONTENT_TYPE, "application/grpc-web+proto".parse().unwrap());
        Ok(res)
    }

    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle));
    let res = svc.serve(req).await.unwrap();

    assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
}
/// The configured quality is honoured: the `br` body must be byte-identical to an
/// independent async-compression brotli encoding of the same data at
/// `CompressionLevel::Best`.
#[tokio::test]
async fn compress_with_quality() {
    const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
    let level = CompressionLevel::Best;

    let upstream = service_fn(async |_| {
        let resp = Response::builder()
            .body(Body::from(DATA.as_bytes()))
            .unwrap();
        Ok::<_, std::io::Error>(resp)
    });
    let svc = Compression::new(upstream).with_quality(level);
    let req = Request::builder()
        .header("accept-encoding", "br")
        .body(Body::empty())
        .unwrap();
    let actual = svc
        .serve(req)
        .await
        .unwrap()
        .into_body()
        .collect()
        .await
        .unwrap()
        .to_bytes();

    // Reference output: stream DATA through a brotli encoder at the same quality.
    let expected = {
        use async_compression::tokio::bufread::BrotliEncoder;

        let source = StreamReader::new(Box::pin(rama_core::futures::stream::once(async {
            Ok::<_, std::io::Error>(DATA.as_bytes())
        })));
        let mut encoder = BrotliEncoder::with_quality(source, level.into_async_compression());
        let mut out = Vec::new();
        encoder.read_to_end(&mut out).await.unwrap();
        out
    };

    assert_eq!(
        actual,
        expected.as_slice(),
        "Compression level is not respected"
    );
}
/// A ranged request answered with a `Content-Range` response is passed through
/// unencoded, keeping `Accept-Ranges` and the exact body bytes.
#[tokio::test]
async fn should_not_compress_ranges() {
    let upstream = service_fn(async |_| {
        let mut res = Response::new(Body::from("Hello"));
        res.headers_mut()
            .insert(ACCEPT_RANGES, "bytes".parse().unwrap());
        res.headers_mut()
            .insert(CONTENT_RANGE, "bytes 0-4/*".parse().unwrap());
        Ok::<_, std::io::Error>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(Always);
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .header(RANGE, "bytes=0-4")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();

    assert_eq!(res.headers()[ACCEPT_RANGES], "bytes");
    assert!(!res.headers().contains_key(CONTENT_ENCODING));

    let body = res.into_body().collect().await.unwrap().to_bytes();
    assert_eq!(body, "Hello");
}
/// When a response is gzip-encoded, `Accept-Ranges` is removed, and the body
/// still gunzips to the original payload.
#[tokio::test]
async fn should_strip_accept_ranges_header_when_compressing() {
    let upstream = service_fn(async |_| {
        let mut res = Response::new(Body::from("Hello, World!"));
        res.headers_mut()
            .insert(ACCEPT_RANGES, "bytes".parse().unwrap());
        Ok::<_, std::io::Error>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(Always);
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();

    assert!(!res.headers().contains_key(ACCEPT_RANGES));
    assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

    let compressed = res.into_body().collect().await.unwrap().to_bytes();
    let mut text = String::new();
    GzDecoder::new(&compressed[..])
        .read_to_string(&mut text)
        .unwrap();
    assert_eq!(text, "Hello, World!");
}
/// Trailers attached to an empty body survive the compression layer unchanged.
#[tokio::test]
async fn trailers_with_empty_body() {
    let upstream = service_fn(|_req: Request<Body>| async {
        let mut trailers = HeaderMap::new();
        trailers.insert("grpc-status", "0".parse().unwrap());
        trailers.insert("grpc-message", "OK".parse().unwrap());
        let res = Response::builder()
            .body(Body::empty().with_trailer_headers(trailers))
            .unwrap();
        Ok::<_, Infallible>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(Always);
    let req = Request::builder()
        .header("accept-encoding", "gzip")
        .body(Body::empty())
        .unwrap();

    let collected = svc
        .serve(req)
        .await
        .unwrap()
        .into_body()
        .collect()
        .await
        .unwrap();
    let trailers = collected.trailers().unwrap();
    assert_eq!(trailers["grpc-status"], "0");
    assert_eq!(trailers["grpc-message"], "OK");
}
#[tokio::test]
async fn trailers_with_streamed_body() {
let svc = service_fn(|_req: Request<Body>| async {
let stream = rama_core::stream::iter(vec![
Ok::<_, BoxError>(Bytes::from("chunk1")),
Ok(Bytes::from("chunk2")),
Ok(Bytes::from("chunk3")),
]);
let mut trailers = HeaderMap::new();
trailers.insert("grpc-status", "0".parse().unwrap());
let body = Body::from_stream(stream).with_trailer_headers(trailers);
Ok::<_, Infallible>(Response::builder().body(body).unwrap())
});
let svc = Compression::new(svc).with_compress_predicate(Always);
let req = Request::builder()
.header("accept-encoding", "gzip")
.body(Body::empty())
.unwrap();
let res = svc.serve(req).await.unwrap();
let collected = res.into_body().collect().await.unwrap();
let trailers = collected.trailers().cloned().unwrap();
let compressed_data = collected.to_bytes();
let mut decoder = GzDecoder::new(&compressed_data[..]);
let mut decompressed = String::new();
decoder.read_to_string(&mut decompressed).unwrap();
assert_eq!(decompressed, "chunk1chunk2chunk3");
assert_eq!(trailers["grpc-status"], "0");
}
/// Trailers also survive on an `application/grpc-web+proto` response whose body is
/// twice `SizeAbove::DEFAULT_MIN_SIZE`.
#[tokio::test]
async fn trailers_with_grpc_web_content_type() {
    let upstream = service_fn(|_req: Request<Body>| async {
        let payload = "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize);
        let mut trailers = HeaderMap::new();
        trailers.insert("grpc-status", "0".parse().unwrap());
        let mut res = Response::new(Body::from(payload).with_trailer_headers(trailers));
        res.headers_mut()
            .insert(CONTENT_TYPE, "application/grpc-web+proto".parse().unwrap());
        Ok::<_, Infallible>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(Always);
    let req = Request::builder()
        .header("accept-encoding", "gzip")
        .body(Body::empty())
        .unwrap();

    let res = svc.serve(req).await.unwrap();
    let collected = res.into_body().collect().await.unwrap();
    let trailers = collected.trailers().unwrap();
    assert_eq!(trailers["grpc-status"], "0");
}
#[tokio::test]
async fn size_hint_identity() {
const MSG: &str = "Hello, world!";
let svc = service_fn(async |_| Ok::<_, std::io::Error>(Response::new(Body::from(MSG))));
let svc = Compression::new(svc);
let req = Request::new(Body::empty());
let res = svc.serve(req).await.unwrap();
let body = res.into_body();
assert_eq!(body.size_hint().exact().unwrap(), MSG.len() as u64);
}
/// A `HEAD` request never yields a `Content-Encoding`, even with an
/// always-compress predicate.
#[tokio::test]
async fn does_not_compress_head_response() {
    use rama_http_types::Method;

    let req = Request::builder()
        .method(Method::HEAD)
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);
    let res = svc.serve(req).await.unwrap();

    let has_encoding = res.headers().contains_key(CONTENT_ENCODING);
    assert!(!has_encoding, "HEAD response must not carry Content-Encoding");
}
/// A `CONNECT` request never yields a `Content-Encoding`, even with an
/// always-compress predicate.
#[tokio::test]
async fn does_not_compress_connect_response() {
    use rama_http_types::Method;

    let req = Request::builder()
        .method(Method::CONNECT)
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);
    let res = svc.serve(req).await.unwrap();

    let has_encoding = res.headers().contains_key(CONTENT_ENCODING);
    assert!(!has_encoding, "CONNECT response must not carry Content-Encoding");
}
/// An upstream `204` response does not gain a `Content-Encoding`, even with an
/// always-compress predicate.
#[tokio::test]
async fn does_not_compress_204_response() {
    let upstream = service_fn(async |_| {
        let res = Response::builder().status(204).body(Body::empty()).unwrap();
        Ok::<_, Infallible>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(Always);
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();

    let has_encoding = res.headers().contains_key(CONTENT_ENCODING);
    assert!(!has_encoding, "204 response must not carry Content-Encoding");
}
/// An upstream `304` response does not gain a `Content-Encoding`, even with an
/// always-compress predicate.
#[tokio::test]
async fn does_not_compress_304_response() {
    let upstream = service_fn(async |_| {
        let res = Response::builder().status(304).body(Body::empty()).unwrap();
        Ok::<_, Infallible>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(Always);
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();

    let has_encoding = res.headers().contains_key(CONTENT_ENCODING);
    assert!(!has_encoding, "304 response must not carry Content-Encoding");
}
/// An upstream `100` (informational) response does not gain a `Content-Encoding`,
/// even with an always-compress predicate.
#[tokio::test]
async fn does_not_compress_1xx_response() {
    let upstream = service_fn(async |_| {
        let res = Response::builder().status(100).body(Body::empty()).unwrap();
        Ok::<_, Infallible>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(Always);
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();

    let has_encoding = res.headers().contains_key(CONTENT_ENCODING);
    assert!(!has_encoding, "1xx response must not carry Content-Encoding");
}
/// An upstream `205` response does not gain a `Content-Encoding`, even with an
/// always-compress predicate.
#[tokio::test]
async fn does_not_compress_205_response() {
    let upstream = service_fn(async |_| {
        let res = Response::builder().status(205).body(Body::empty()).unwrap();
        Ok::<_, Infallible>(res)
    });
    let svc = Compression::new(upstream).with_compress_predicate(Always);
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "gzip")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();

    let has_encoding = res.headers().contains_key(CONTENT_ENCODING);
    assert!(!has_encoding, "205 response must not carry Content-Encoding");
}
/// `*;q=0` leaves no acceptable coding, so the service answers `406 Not Acceptable`
/// and still lists `accept-encoding` in `Vary`.
#[tokio::test]
async fn wildcard_q_zero_returns_406() {
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "*;q=0")
        .body(Body::empty())
        .unwrap();
    let res = svc.serve(req).await.unwrap();
    assert_eq!(res.status(), crate::StatusCode::NOT_ACCEPTABLE);

    // Field names in `Vary` are case-insensitive and may be comma-joined with other
    // names in a single value, so match each listed name case-insensitively instead
    // of doing a case-sensitive substring check (and skip non-text values rather
    // than panicking on them).
    let varies_on_accept_encoding = res
        .headers()
        .get_all(crate::header::VARY)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|name| name.trim().eq_ignore_ascii_case("accept-encoding"));
    assert!(
        varies_on_accept_encoding,
        "406 response must list accept-encoding in Vary"
    );
}
/// `*;q=0,gzip` still explicitly allows gzip, so the response is `200` and
/// gzip-encoded.
#[tokio::test]
async fn wildcard_q_zero_with_gzip_picks_gzip() {
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "*;q=0,gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);
    let res = svc.serve(req).await.unwrap();

    assert_eq!(res.status(), crate::StatusCode::OK);
    let encoding = res.headers().get(CONTENT_ENCODING).unwrap();
    assert_eq!(encoding, "gzip");
}
/// A bare `*` accepts any coding, so the response is `200` and carries some
/// `Content-Encoding`.
#[tokio::test]
async fn wildcard_alone_compresses() {
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "*")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);
    let res = svc.serve(req).await.unwrap();

    assert_eq!(res.status(), crate::StatusCode::OK);
    let has_encoding = res.headers().contains_key(CONTENT_ENCODING);
    assert!(has_encoding);
}
/// Rejecting `identity` while offering nothing else yields `406 Not Acceptable`.
#[tokio::test]
async fn identity_q_zero_alone_returns_406() {
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "identity;q=0")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);
    let status = svc.serve(req).await.unwrap().status();

    assert_eq!(status, crate::StatusCode::NOT_ACCEPTABLE);
}
/// Rejecting `identity` but allowing `gzip` yields a `200` gzip-encoded response.
#[tokio::test]
async fn identity_q_zero_with_gzip_picks_gzip() {
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "identity;q=0,gzip")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle)).with_compress_predicate(Always);
    let res = svc.serve(req).await.unwrap();

    assert_eq!(res.status(), crate::StatusCode::OK);
    let encoding = res.headers().get(CONTENT_ENCODING).unwrap();
    assert_eq!(encoding, "gzip");
}
/// With `with_enforce_not_acceptable(false)`, an unsatisfiable `*;q=0` falls back
/// to a `200` unencoded response instead of `406`.
#[tokio::test]
async fn enforce_not_acceptable_opt_out_falls_back_to_identity() {
    let req = Request::builder()
        .header(ACCEPT_ENCODING, "*;q=0")
        .body(Body::empty())
        .unwrap();
    let svc = Compression::new(service_fn(handle))
        .with_compress_predicate(Always)
        .with_enforce_not_acceptable(false);
    let res = svc.serve(req).await.unwrap();

    assert_eq!(res.status(), crate::StatusCode::OK);
    let has_encoding = res.headers().contains_key(CONTENT_ENCODING);
    assert!(!has_encoding);
}
}