#[cfg(any(feature = "compression", feature = "compression-brotli"))]
use async_compression::tokio::bufread::BrotliEncoder;
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
use async_compression::tokio::bufread::DeflateEncoder;
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
use async_compression::tokio::bufread::GzipEncoder;
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
use async_compression::tokio::bufread::ZstdEncoder;
use bytes::Bytes;
use futures_util::Stream;
use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
use hyper::{
header::{CONTENT_ENCODING, CONTENT_LENGTH},
Body, Method, Request, Response, StatusCode,
};
use lazy_static::lazy_static;
use mime_guess::{mime, Mime};
use pin_project::pin_project;
use std::collections::HashSet;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_util::io::{ReaderStream, StreamReader};
use crate::{
error_page,
handler::RequestHandlerOpts,
headers_ext::{AcceptEncoding, ContentCoding},
http_ext::MethodExt,
settings::CompressionLevel,
Error, Result,
};
lazy_static! {
    /// Essence strings of MIME types that `is_text` accepts as compressible even
    /// though neither their top-level type nor their suffix identifies them as text.
    static ref TEXT_MIME_TYPES: HashSet<&'static str> = HashSet::from([
        "application/rtf",
        "application/javascript",
        "application/json",
        "application/xml",
        "font/ttf",
        "application/font-sfnt",
        "application/vnd.ms-fontobject",
        "application/wasm",
    ]);
}
/// Content codings compiled into this build.
///
/// Each entry is gated behind its own Cargo feature (or the umbrella
/// `compression` feature), so the slice only lists codings whose encoder is
/// actually available. It is used to filter the codings requested by a client.
const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
ContentCoding::DEFLATE,
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
ContentCoding::GZIP,
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
ContentCoding::BROTLI,
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
ContentCoding::ZSTD,
];
/// Stores the compression settings in the request handler options and logs them.
///
/// * `enabled` - whether on-the-fly compression is turned on.
/// * `level` - the configured compression level.
/// * `handler_opts` - handler options that receive both values.
pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
    // Names of the formats compiled in, used only for the log line below.
    const FORMATS: &[&str] = &[
        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
        "deflate",
        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
        "gzip",
        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
        "brotli",
        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
        "zstd",
    ];

    handler_opts.compression = enabled;
    handler_opts.compression_level = level;

    let formats = FORMATS.join(",");
    tracing::info!(
        "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
        formats
    );
}
/// Applies on-the-fly compression to a response when it is enabled.
///
/// The response is returned untouched when compression is disabled in `opts`
/// or when it already carries a `Content-Encoding` header. Otherwise
/// `accept-encoding` is added to the `Vary` header and the body is handed to
/// [`auto`].
///
/// # Errors
///
/// If [`auto`] fails, the error is logged and the result of building a
/// `500 Internal Server Error` page is returned instead.
pub(crate) fn post_process<T>(
    opts: &RequestHandlerOpts,
    req: &Request<T>,
    mut resp: Response<Body>,
) -> Result<Response<Body>, Error> {
    if !opts.compression {
        return Ok(resp);
    }

    // A response that already declares a content encoding is not encoded again.
    if resp.headers().contains_key(CONTENT_ENCODING) {
        return Ok(resp);
    }

    // Merge every readable, non-empty `Vary` entry into one list so that no
    // existing entry is lost when the header is replaced below, and so that an
    // empty or non-text value cannot produce a list with a leading comma.
    let mut vary = resp
        .headers()
        .get_all(hyper::header::VARY)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .filter(|value| !value.trim().is_empty())
        .collect::<Vec<_>>()
        .join(",");
    if !vary.is_empty() {
        vary.push(',');
    }
    vary.push_str(hyper::header::ACCEPT_ENCODING.as_str());

    // Fall back to the bare header name instead of panicking if the merged
    // list is somehow not a valid header value.
    let vary = HeaderValue::from_str(&vary)
        .unwrap_or_else(|_| HeaderValue::from_name(hyper::header::ACCEPT_ENCODING));
    resp.headers_mut().insert(hyper::header::VARY, vary);

    match auto(req.method(), req.headers(), opts.compression_level, resp) {
        Ok(resp) => Ok(resp),
        Err(err) => {
            tracing::error!("error during body compression: {:?}", err);
            error_page::error_response(
                req.uri(),
                req.method(),
                &StatusCode::INTERNAL_SERVER_ERROR,
                &opts.page404,
                &opts.page50x,
            )
        }
    }
}
/// Compresses a response body using the coding preferred by the client.
///
/// The response is returned unchanged when:
/// - the request method is `HEAD` or `OPTIONS`,
/// - no available coding is found in the request headers,
/// - the response has a `Content-Type` that [`is_text`] rejects, or
/// - the selected coding has no matching encoder compiled in.
///
/// A response without a `Content-Type` header is still compressed.
pub fn auto(
    method: &Method,
    headers: &HeaderMap<HeaderValue>,
    level: CompressionLevel,
    resp: Response<Body>,
) -> Result<Response<Body>> {
    if method.is_head() || method.is_options() {
        return Ok(resp);
    }

    let Some(encoding) = get_preferred_encoding(headers) else {
        return Ok(resp);
    };
    tracing::trace!(
        "preferred encoding selected from the accept-encoding header: {:?}",
        encoding
    );

    // Only a present, non-text content type vetoes compression.
    let has_non_text_type = resp
        .headers()
        .typed_get::<ContentType>()
        .is_some_and(|content_type| !is_text(Mime::from(content_type)));
    if has_non_text_type {
        return Ok(resp);
    }

    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    if encoding == ContentCoding::GZIP {
        let (parts, payload) = resp.into_parts();
        return Ok(gzip(parts, CompressableBody::from(payload), level));
    }

    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    if encoding == ContentCoding::DEFLATE {
        let (parts, payload) = resp.into_parts();
        return Ok(deflate(parts, CompressableBody::from(payload), level));
    }

    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    if encoding == ContentCoding::BROTLI {
        let (parts, payload) = resp.into_parts();
        return Ok(brotli(parts, CompressableBody::from(payload), level));
    }

    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    if encoding == ContentCoding::ZSTD {
        let (parts, payload) = resp.into_parts();
        return Ok(zstd(parts, CompressableBody::from(payload), level));
    }

    tracing::trace!("no compression feature matched the preferred encoding, probably not enabled or unsupported");
    Ok(resp)
}
/// Reports whether a MIME type is treated as text-based, i.e. worth compressing.
///
/// A type qualifies when its top-level type is `text`, when its suffix is
/// `xml` or `json`, or when its essence string is listed in `TEXT_MIME_TYPES`.
fn is_text(mime: Mime) -> bool {
    if mime.type_() == mime::TEXT {
        return true;
    }

    let has_text_suffix = matches!(
        mime.suffix(),
        Some(suffix) if suffix == mime::XML || suffix == mime::JSON
    );
    has_text_suffix || TEXT_MIME_TYPES.contains(mime.essence_str())
}
/// Builds a response whose body is compressed with GZIP while it is streamed.
///
/// `Content-Length` is removed from `head`, and `Content-Encoding` is set to
/// `gzip`, appended to any coding that was already present. `level` is
/// converted with `into_algorithm_level`, which is given `4` (presumably the
/// fallback quality; confirm against `CompressionLevel`).
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // Header mutations are kept in their original order.
    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::GZIP);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
/// Builds a response whose body is compressed with DEFLATE while it is streamed.
///
/// `Content-Length` is removed from `head`, and `Content-Encoding` is set to
/// `deflate`, appended to any coding that was already present. `level` is
/// converted with `into_algorithm_level`, which is given `4` (presumably the
/// fallback quality; confirm against `CompressionLevel`).
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // Header mutations are kept in their original order.
    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::DEFLATE);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
/// Builds a response whose body is compressed with Brotli while it is streamed.
///
/// `Content-Length` is removed from `head`, and `Content-Encoding` is set to
/// the Brotli coding, appended to any coding that was already present. `level`
/// is converted with `into_algorithm_level`, which is given `4` (presumably the
/// fallback quality; confirm against `CompressionLevel`).
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // Header mutations are kept in their original order.
    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::BROTLI);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
/// Builds a response whose body is compressed with Zstandard while it is streamed.
///
/// `Content-Length` is removed from `head`, and `Content-Encoding` is set to
/// the Zstandard coding, appended to any coding that was already present.
/// `level` is converted with `into_algorithm_level`, which is given `3`
/// (presumably the fallback quality; confirm against `CompressionLevel`).
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // Header mutations are kept in their original order.
    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::ZSTD);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
/// Produces the `Content-Encoding` value to send after applying `coding`.
///
/// When `existing` holds a value readable as a string, `coding` is appended to
/// it after `", "`. In every other case (no existing value, an unreadable one,
/// or a combined string that is not a valid header value) the result is
/// `coding` on its own.
pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
    let combined = existing
        .as_ref()
        .and_then(|previous| previous.to_str().ok())
        .map(|previous| format!("{previous}, {}", coding.as_str()));

    match combined {
        Some(list) => HeaderValue::from_str(&list).unwrap_or_else(|_| coding.into()),
        None => coding.into(),
    }
}
#[inline(always)]
pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
for encoding in accept_encoding.sorted_encodings() {
if AVAILABLE_ENCODINGS.contains(&encoding) {
return Some(encoding);
}
}
}
None
}
#[inline(always)]
pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
return accept_encoding
.sorted_encodings()
.filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
.collect::<Vec<_>>();
}
vec![]
}
/// Wrapper around a stream of `Result<Bytes, E>` chunks.
///
/// Its `Stream` implementation yields `std::io::Result<Bytes>` items, which is
/// the form the compression functions in this module pass to `StreamReader`.
#[pin_project]
#[derive(Debug)]
pub struct CompressableBody<S, E>
where
S: Stream<Item = Result<Bytes, E>>,
E: std::error::Error,
{
/// The wrapped stream; pinned structurally so it can be polled in place.
#[pin]
body: S,
}
/// Forwards the wrapped stream's chunks, converting its errors to `std::io::Error`.
impl<S, E> Stream for CompressableBody<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: std::error::Error,
{
    type Item = std::io::Result<Bytes>;

    /// Polls the wrapped stream and passes its chunks through unchanged.
    ///
    /// An error from the wrapped stream is reported as
    /// `ErrorKind::InvalidData`, carrying the source error's message so the
    /// cause of a failed body read is not lost.
    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use std::io::{Error, ErrorKind};

        // `E` has no `Send + Sync + 'static` bound, so it cannot be boxed as
        // the error source; its rendered message is kept instead.
        self.project()
            .body
            .poll_next(ctx)
            .map_err(|err| Error::new(ErrorKind::InvalidData, err.to_string()))
    }
}
/// Lets a hyper `Body` be passed wherever a `CompressableBody` is expected.
impl From<Body> for CompressableBody<Body, hyper::Error> {
    /// Wraps `body` without modifying it.
    #[inline(always)]
    fn from(body: Body) -> Self {
        Self { body }
    }
}