Skip to main content

static_web_server/
compression.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
//! Auto-compression module that compresses response bodies on the fly.
7//!
8
// Part of this file is borrowed from <https://github.com/seanmonstar/warp/pull/513>.
10
11#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24    Body, Method, Request, Response, StatusCode,
25    header::{CONTENT_ENCODING, CONTENT_LENGTH},
26};
27use mime_guess::{Mime, mime};
28use pin_project::pin_project;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use tokio_util::io::{ReaderStream, StreamReader};
32
33use crate::{
34    Error, Result, error_page,
35    handler::RequestHandlerOpts,
36    headers_ext::{AcceptEncoding, ContentCoding},
37    http_ext::MethodExt,
38    settings::CompressionLevel,
39};
40
/// Essences of MIME types that are treated as compressible even though the
/// generic checks in `is_text` (a `text/*` type, or a `+xml` / `+json`
/// structured-syntax suffix) would not recognize them.
const TEXT_MIME_TYPES: [&str; 8] = [
    // Scripts and structured documents.
    "application/javascript",
    "application/json",
    "application/rtf",
    "application/xml",
    // Font formats.
    "application/font-sfnt",
    "application/vnd.ms-fontobject",
    "font/ttf",
    // WebAssembly modules.
    "application/wasm",
];
52
53/// List of encodings that can be handled given enabled features.
54const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
55    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
56    ContentCoding::DEFLATE,
57    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
58    ContentCoding::GZIP,
59    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
60    ContentCoding::BROTLI,
61    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
62    ContentCoding::ZSTD,
63];
64
65/// Initializes dynamic compression.
66pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
67    handler_opts.compression = enabled;
68    handler_opts.compression_level = level;
69
70    const FORMATS: &[&str] = &[
71        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
72        "deflate",
73        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
74        "gzip",
75        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
76        "brotli",
77        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
78        "zstd",
79    ];
80    tracing::info!(
81        "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
82        FORMATS.join(",")
83    );
84}
85
86/// Post-processing to dynamically compress the response if necessary.
87pub(crate) fn post_process<T>(
88    opts: &RequestHandlerOpts,
89    req: &Request<T>,
90    mut resp: Response<Body>,
91) -> Result<Response<Body>, Error> {
92    if !opts.compression {
93        return Ok(resp);
94    }
95
96    let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
97    if is_precompressed {
98        return Ok(resp);
99    }
100
101    // Compression content encoding varies so use a `Vary` header
102    let enc = HeaderValue::from_name(hyper::header::ACCEPT_ENCODING);
103    let value = resp.headers().get(hyper::header::VARY).map_or(enc, |h| {
104        let mut a = h.to_str().unwrap_or_default().to_owned();
105        let b = hyper::header::ACCEPT_ENCODING.as_str();
106        if !a.contains(b) {
107            if !a.is_empty() {
108                a.push(',');
109            }
110            a.push_str(b);
111        }
112        HeaderValue::from_str(a.as_str()).unwrap()
113    });
114
115    resp.headers_mut().insert(hyper::header::VARY, value);
116
117    // Auto compression based on the `Accept-Encoding` header
118    match auto(req.method(), req.headers(), opts.compression_level, resp) {
119        Ok(resp) => Ok(resp),
120        Err(err) => {
121            tracing::error!("error during body compression: {:?}", err);
122            error_page::error_response(
123                req.uri(),
124                req.method(),
125                &StatusCode::INTERNAL_SERVER_ERROR,
126                &opts.page404,
127                &opts.page50x,
128            )
129        }
130    }
131}
132
133/// Create a wrapping handler that compresses the Body of a [`hyper::Response`]
134/// using gzip, `deflate`, `brotli` or `zstd` if is specified in the `Accept-Encoding` header, adding
135/// `content-encoding: <coding>` to the Response's [`HeaderMap`].
136/// It also provides the ability to apply compression for text-based MIME types only.
137pub fn auto(
138    method: &Method,
139    headers: &HeaderMap<HeaderValue>,
140    level: CompressionLevel,
141    resp: Response<Body>,
142) -> Result<Response<Body>> {
143    // Skip compression for HEAD and OPTIONS request methods
144    if method.is_head() || method.is_options() {
145        return Ok(resp);
146    }
147
148    // Compress response based on Accept-Encoding header
149    if let Some(encoding) = get_preferred_encoding(headers) {
150        tracing::trace!(
151            "preferred encoding selected from the accept-encoding header: {:?}",
152            encoding
153        );
154
155        // Skip compression for non-text-based MIME types
156        if let Some(content_type) = resp.headers().typed_get::<ContentType>()
157            && !is_text(Mime::from(content_type))
158        {
159            return Ok(resp);
160        }
161
162        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
163        if encoding == ContentCoding::GZIP {
164            let (head, body) = resp.into_parts();
165            return Ok(gzip(head, body.into(), level));
166        }
167
168        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
169        if encoding == ContentCoding::DEFLATE {
170            let (head, body) = resp.into_parts();
171            return Ok(deflate(head, body.into(), level));
172        }
173
174        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
175        if encoding == ContentCoding::BROTLI {
176            let (head, body) = resp.into_parts();
177            return Ok(brotli(head, body.into(), level));
178        }
179
180        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
181        if encoding == ContentCoding::ZSTD {
182            let (head, body) = resp.into_parts();
183            return Ok(zstd(head, body.into(), level));
184        }
185
186        tracing::trace!(
187            "no compression feature matched the preferred encoding, probably not enabled or unsupported"
188        );
189    }
190
191    Ok(resp)
192}
193
194/// Checks whether the MIME type corresponds to any of the known text types.
195fn is_text(mime: Mime) -> bool {
196    mime.type_() == mime::TEXT
197        || mime
198            .suffix()
199            .is_some_and(|suffix| suffix == mime::XML || suffix == mime::JSON)
200        || TEXT_MIME_TYPES.contains(&mime.essence_str())
201}
202
/// Compresses the body of a [`Response`] with gzip and adds
/// `content-encoding: gzip` to the response's [`HeaderMap`].
///
/// `level` is resolved through `CompressionLevel::into_algorithm_level`, with 4
/// passed as this algorithm's default. Any previous `content-encoding` value is
/// extended via [`create_encoding_header`], and `content-length` is removed
/// because it no longer describes the streamed, compressed body.
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::GZIP),
    );

    Response::from_parts(head, compressed)
}
229
/// Compresses the body of a [`Response`] with deflate and adds
/// `content-encoding: deflate` to the response's [`HeaderMap`].
///
/// `level` is resolved through `CompressionLevel::into_algorithm_level`, with 4
/// passed as this algorithm's default. Any previous `content-encoding` value is
/// extended via [`create_encoding_header`], and `content-length` is removed
/// because it no longer describes the streamed, compressed body.
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::DEFLATE),
    );

    Response::from_parts(head, compressed)
}
259
/// Compresses the body of a [`Response`] with brotli and adds
/// `content-encoding: br` to the response's [`HeaderMap`].
///
/// `level` is resolved through `CompressionLevel::into_algorithm_level`, with 4
/// passed as this algorithm's default. Any previous `content-encoding` value is
/// extended via [`create_encoding_header`], and `content-length` is removed
/// because it no longer describes the streamed, compressed body.
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::BROTLI),
    );

    Response::from_parts(head, compressed)
}
287
/// Compresses the body of a [`Response`] with zstd and adds
/// `content-encoding: zstd` to the response's [`HeaderMap`].
///
/// `level` is resolved through `CompressionLevel::into_algorithm_level`, with 3
/// passed as this algorithm's default. Any previous `content-encoding` value is
/// extended via [`create_encoding_header`], and `content-length` is removed
/// because it no longer describes the streamed, compressed body.
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::ZSTD),
    );

    Response::from_parts(head, compressed)
}
314
315/// Given an optional existing encoding header, appends to the existing or creates a new one.
316pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
317    if let Some(val) = existing
318        && let Ok(str_val) = val.to_str()
319    {
320        return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
321            .unwrap_or_else(|_| coding.into());
322    }
323    coding.into()
324}
325
326/// Try to get the preferred `content-encoding` via the `accept-encoding` header.
327#[inline(always)]
328pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
329    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
330        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
331
332        for encoding in accept_encoding.sorted_encodings() {
333            if AVAILABLE_ENCODINGS.contains(&encoding) {
334                return Some(encoding);
335            }
336        }
337    }
338    None
339}
340
341/// Get the `content-encodings` via the `accept-encoding` header.
342#[inline(always)]
343pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
344    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
345        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
346
347        return accept_encoding
348            .sorted_encodings()
349            .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
350            .collect::<Vec<_>>();
351    }
352    vec![]
353}
354
355/// A wrapper around any type that implements [`Stream`](futures_util::Stream) to be
356/// compatible with async_compression's `Stream` based encoders.
357#[pin_project]
358#[derive(Debug)]
359pub struct CompressableBody<S, E>
360where
361    S: Stream<Item = Result<Bytes, E>>,
362    E: std::error::Error,
363{
364    #[pin]
365    body: S,
366}
367
368impl<S, E> Stream for CompressableBody<S, E>
369where
370    S: Stream<Item = Result<Bytes, E>>,
371    E: std::error::Error,
372{
373    type Item = std::io::Result<Bytes>;
374
375    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
376        use std::io::{Error, ErrorKind};
377
378        let pin = self.project();
379        S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
380    }
381}
382
383impl From<Body> for CompressableBody<Body, hyper::Error> {
384    #[inline(always)]
385    fn from(body: Body) -> Self {
386        CompressableBody { body }
387    }
388}