#[cfg(any(feature = "compression", feature = "compression-brotli"))]
use async_compression::tokio::bufread::BrotliEncoder;
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
use async_compression::tokio::bufread::DeflateEncoder;
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
use async_compression::tokio::bufread::GzipEncoder;
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
use async_compression::tokio::bufread::ZstdEncoder;
use bytes::Bytes;
use futures_util::Stream;
use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
use hyper::{
Body, Method, Request, Response, StatusCode,
header::{CONTENT_ENCODING, CONTENT_LENGTH},
};
use mime_guess::Mime;
use pin_project::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_util::io::{ReaderStream, StreamReader};
use crate::{
Error, Result, error_page,
handler::RequestHandlerOpts,
headers_ext::{AcceptEncoding, ContentCoding},
http_ext::MethodExt,
mime_ext::MimeExt,
settings::CompressionLevel,
};
/// Smallest declared `Content-Length` (in bytes) that `auto` will compress.
///
/// Responses whose `Content-Length` parses to a smaller value are sent uncompressed;
/// responses without a parseable `Content-Length` are not subject to this threshold.
const MIN_COMPRESS_SIZE: usize = 200;
/// Content codings this build can produce, one entry per enabled compression feature.
///
/// Used to filter the client's `Accept-Encoding` preferences down to encoders that are
/// actually compiled in.
const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    ContentCoding::DEFLATE,
    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    ContentCoding::GZIP,
    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    ContentCoding::BROTLI,
    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    ContentCoding::ZSTD,
];
/// Stores the auto-compression switch and level in `handler_opts`, then logs the
/// settings together with the encoder names compiled into this build.
pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
    // Human-readable encoder names, used only for the startup log line below.
    const FORMATS: &[&str] = &[
        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
        "deflate",
        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
        "gzip",
        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
        "brotli",
        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
        "zstd",
    ];

    handler_opts.compression = enabled;
    handler_opts.compression_level = level;

    let formats = FORMATS.join(",");
    tracing::info!(
        "auto compression: enabled={enabled}, formats={formats}, compression level={level:?}"
    );
}
pub(crate) fn post_process<T>(
opts: &RequestHandlerOpts,
req: &Request<T>,
mut resp: Response<Body>,
) -> Result<Response<Body>, Error> {
if !opts.compression {
return Ok(resp);
}
let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
if is_precompressed {
return Ok(resp);
}
let enc = HeaderValue::from_name(hyper::header::ACCEPT_ENCODING);
let value = resp.headers().get(hyper::header::VARY).map_or(enc, |h| {
let mut a = h.to_str().unwrap_or_default().to_owned();
let b = hyper::header::ACCEPT_ENCODING.as_str();
if !a.contains(b) {
if !a.is_empty() {
a.push(',');
}
a.push_str(b);
}
HeaderValue::from_str(a.as_str()).unwrap()
});
resp.headers_mut().insert(hyper::header::VARY, value);
match auto(req.method(), req.headers(), opts.compression_level, resp) {
Ok(resp) => Ok(resp),
Err(err) => {
tracing::error!("error during body compression: {:?}", err);
error_page::error_response(
req.uri(),
req.method(),
&StatusCode::INTERNAL_SERVER_ERROR,
&opts.page404,
&opts.page50x,
)
}
}
}
pub fn auto(
method: &Method,
headers: &HeaderMap<HeaderValue>,
level: CompressionLevel,
resp: Response<Body>,
) -> Result<Response<Body>> {
if method.is_head() || method.is_options() {
return Ok(resp);
}
if let Some(encoding) = get_preferred_encoding(headers) {
tracing::trace!(
"preferred encoding selected from the accept-encoding header: {:?}",
encoding
);
if let Some(content_type) = resp.headers().typed_get::<ContentType>()
&& !Mime::from(content_type).is_compressible()
{
return Ok(resp);
}
if let Some(content_length) = resp
.headers()
.get(CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<usize>().ok())
&& content_length < MIN_COMPRESS_SIZE
{
tracing::trace!(
"skipping compression: content-length ({content_length}) below minimum ({MIN_COMPRESS_SIZE})",
);
return Ok(resp);
}
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
if encoding == ContentCoding::GZIP {
let (head, body) = resp.into_parts();
return Ok(gzip(head, body.into(), level));
}
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
if encoding == ContentCoding::DEFLATE {
let (head, body) = resp.into_parts();
return Ok(deflate(head, body.into(), level));
}
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
if encoding == ContentCoding::BROTLI {
let (head, body) = resp.into_parts();
return Ok(brotli(head, body.into(), level));
}
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
if encoding == ContentCoding::ZSTD {
let (head, body) = resp.into_parts();
return Ok(zstd(head, body.into(), level));
}
tracing::trace!(
"no compression feature matched the preferred encoding, probably not enabled or unsupported"
);
}
Ok(resp)
}
/// Wraps `body` in a streaming GZIP encoder and rebuilds the response around it.
///
/// `gzip` is appended to any existing `Content-Encoding` value, and `Content-Length`
/// is removed because the encoded size is not known up front.
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Algorithm-specific default handed to `CompressionLevel::into_algorithm_level`.
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;
    tracing::trace!("compressing response body on the fly using GZIP");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let reader = StreamReader::new(body);
    let encoded = ReaderStream::new(GzipEncoder::with_quality(reader, quality));

    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::GZIP);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, Body::wrap_stream(encoded))
}
/// Wraps `body` in a streaming DEFLATE encoder and rebuilds the response around it.
///
/// `deflate` is appended to any existing `Content-Encoding` value, and `Content-Length`
/// is removed because the encoded size is not known up front.
///
/// NOTE(review): RFC 9110 defines the `deflate` content coding as zlib-wrapped
/// (RFC 1950) data; confirm whether `DeflateEncoder` emits zlib or raw DEFLATE framing.
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Algorithm-specific default handed to `CompressionLevel::into_algorithm_level`.
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;
    tracing::trace!("compressing response body on the fly using DEFLATE");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let reader = StreamReader::new(body);
    let encoded = ReaderStream::new(DeflateEncoder::with_quality(reader, quality));

    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::DEFLATE);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, Body::wrap_stream(encoded))
}
/// Wraps `body` in a streaming Brotli encoder and rebuilds the response around it.
///
/// The Brotli coding is appended to any existing `Content-Encoding` value, and
/// `Content-Length` is removed because the encoded size is not known up front.
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Algorithm-specific default handed to `CompressionLevel::into_algorithm_level`.
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;
    tracing::trace!("compressing response body on the fly using BROTLI");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let reader = StreamReader::new(body);
    let encoded = ReaderStream::new(BrotliEncoder::with_quality(reader, quality));

    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::BROTLI);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, Body::wrap_stream(encoded))
}
/// Wraps `body` in a streaming Zstandard encoder and rebuilds the response around it.
///
/// `zstd` is appended to any existing `Content-Encoding` value, and `Content-Length`
/// is removed because the encoded size is not known up front.
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Algorithm-specific default handed to `CompressionLevel::into_algorithm_level`.
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;
    tracing::trace!("compressing response body on the fly using ZSTD");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let reader = StreamReader::new(body);
    let encoded = ReaderStream::new(ZstdEncoder::with_quality(reader, quality));

    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::ZSTD);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, Body::wrap_stream(encoded))
}
/// Builds the `Content-Encoding` value after applying `coding`.
///
/// If `existing` holds a value that is valid visible ASCII, `coding` is appended to it
/// as `"<existing>, <coding>"`. Otherwise, or if the combined string is not a valid
/// header value, the result is just `coding`.
pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
    let prior = existing.as_ref().and_then(|value| value.to_str().ok());
    match prior {
        Some(prior) => HeaderValue::from_str(&format!("{prior}, {}", coding.as_str()))
            .unwrap_or_else(|_| coding.into()),
        None => coding.into(),
    }
}
#[inline(always)]
pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
for encoding in accept_encoding.sorted_encodings() {
if AVAILABLE_ENCODINGS.contains(&encoding) {
return Some(encoding);
}
}
}
None
}
#[inline(always)]
pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
return accept_encoding
.sorted_encodings()
.filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
.collect::<Vec<_>>();
}
vec![]
}
#[pin_project]
#[derive(Debug)]
pub struct CompressableBody<S, E>
where
S: Stream<Item = Result<Bytes, E>>,
E: std::error::Error,
{
#[pin]
body: S,
}
/// Forwards items from the wrapped stream, converting its errors into `std::io::Error`.
impl<S, E> Stream for CompressableBody<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: std::error::Error,
{
    type Item = std::io::Result<Bytes>;

    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use std::io::{Error, ErrorKind};
        let pin = self.project();
        // Keep the error kind callers may match on (`InvalidData`), but carry the
        // underlying message instead of discarding it; `E` is not required to be
        // `Send + Sync + 'static`, so its text is captured rather than the value itself.
        S::poll_next(pin.body, ctx)
            .map_err(|err| Error::new(ErrorKind::InvalidData, err.to_string()))
    }

    /// Mapping errors does not change the item count, so the inner hint is exact.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.body.size_hint()
    }
}
/// Lets a hyper [`Body`] be passed wherever a `CompressableBody` is expected
/// (e.g. `body.into()` in `auto`).
impl From<Body> for CompressableBody<Body, hyper::Error> {
    #[inline(always)]
    fn from(body: Body) -> Self {
        Self { body }
    }
}
#[cfg(test)]
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
mod tests {
    use super::*;
    use http::header::{ACCEPT_ENCODING, CONTENT_TYPE};

    /// `text/html` response of `size` bytes with a matching `Content-Length`.
    fn text_response_with_size(size: usize) -> Response<Body> {
        let mut resp = Response::new(Body::from(vec![b'x'; size]));
        resp.headers_mut()
            .insert(CONTENT_TYPE, "text/html".parse().unwrap());
        resp.headers_mut()
            .insert(CONTENT_LENGTH, size.to_string().parse().unwrap());
        resp
    }

    /// `text/html` response with no `Content-Length` header.
    fn text_response_without_length() -> Response<Body> {
        let mut resp = Response::new(Body::from("hello world"));
        resp.headers_mut()
            .insert(CONTENT_TYPE, "text/html".parse().unwrap());
        resp
    }

    /// Request headers accepting only gzip.
    fn accept_gzip_headers() -> HeaderMap<HeaderValue> {
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT_ENCODING, "gzip".parse().unwrap());
        headers
    }

    #[test]
    fn small_response_below_threshold_is_not_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE - 1);
        let headers = accept_gzip_headers();
        let result = auto(&Method::GET, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_none(),
            "responses below {MIN_COMPRESS_SIZE} bytes must not be compressed"
        );
    }

    #[test]
    fn response_at_threshold_is_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE);
        let headers = accept_gzip_headers();
        let result = auto(&Method::GET, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_some(),
            "responses at exactly {MIN_COMPRESS_SIZE} bytes must be compressed"
        );
    }

    #[test]
    fn response_above_threshold_is_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE + 1);
        let headers = accept_gzip_headers();
        let result = auto(&Method::GET, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_some(),
            "responses above {MIN_COMPRESS_SIZE} bytes must be compressed"
        );
    }

    #[test]
    fn response_without_content_length_is_compressed() {
        let resp = text_response_without_length();
        let headers = accept_gzip_headers();
        let result = auto(&Method::GET, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_some(),
            "responses without Content-Length must still be compressed"
        );
    }

    #[test]
    fn compressed_response_uses_gzip_and_drops_content_length() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE + 1);
        let headers = accept_gzip_headers();
        let result = auto(&Method::GET, &headers, CompressionLevel::Default, resp).unwrap();
        assert_eq!(result.headers()[CONTENT_ENCODING], "gzip");
        assert!(
            result.headers().get(CONTENT_LENGTH).is_none(),
            "the original Content-Length no longer describes the encoded body"
        );
    }

    #[test]
    fn small_response_head_method_is_not_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE - 1);
        let headers = accept_gzip_headers();
        let result = auto(&Method::HEAD, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_none(),
            "HEAD requests are never compressed regardless of size"
        );
    }

    #[test]
    fn options_method_is_not_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE + 1);
        let headers = accept_gzip_headers();
        let result = auto(&Method::OPTIONS, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_none(),
            "OPTIONS requests are never compressed regardless of size"
        );
    }

    #[test]
    fn request_without_accept_encoding_is_not_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE + 1);
        let headers = HeaderMap::new();
        let result = auto(&Method::GET, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_none(),
            "without Accept-Encoding the response must be sent as is"
        );
    }

    #[test]
    fn identity_only_accept_encoding_is_not_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE + 1);
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT_ENCODING, "identity".parse().unwrap());
        let result = auto(&Method::GET, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_none(),
            "identity is not a compression encoding"
        );
    }

    #[test]
    fn non_compressible_content_type_is_not_compressed() {
        let mut resp = Response::new(Body::from(vec![b'x'; MIN_COMPRESS_SIZE + 100]));
        resp.headers_mut()
            .insert(CONTENT_TYPE, "image/png".parse().unwrap());
        resp.headers_mut().insert(
            CONTENT_LENGTH,
            (MIN_COMPRESS_SIZE + 100).to_string().parse().unwrap(),
        );
        let headers = accept_gzip_headers();
        let result = auto(&Method::GET, &headers, CompressionLevel::Default, resp).unwrap();
        assert!(
            result.headers().get(CONTENT_ENCODING).is_none(),
            "non-compressible content-types are never compressed"
        );
    }

    #[test]
    fn preferred_encoding_is_gzip_when_requested() {
        assert_eq!(
            get_preferred_encoding(&accept_gzip_headers()),
            Some(ContentCoding::GZIP)
        );
        assert_eq!(get_preferred_encoding(&HeaderMap::new()), None);
    }

    #[test]
    fn create_encoding_header_appends_to_existing_coding() {
        let value = create_encoding_header(Some(HeaderValue::from_static("br")), ContentCoding::GZIP);
        assert_eq!(value, "br, gzip");
    }

    #[test]
    fn create_encoding_header_without_existing_value() {
        assert_eq!(create_encoding_header(None, ContentCoding::GZIP), "gzip");
    }
}