1#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24 Body, Method, Request, Response, StatusCode,
25 header::{CONTENT_ENCODING, CONTENT_LENGTH},
26};
27use mime_guess::{Mime, mime};
28use pin_project::pin_project;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use tokio_util::io::{ReaderStream, StreamReader};
32
33use crate::{
34 Error, Result, error_page,
35 handler::RequestHandlerOpts,
36 headers_ext::{AcceptEncoding, ContentCoding},
37 http_ext::MethodExt,
38 settings::CompressionLevel,
39};
40
/// Extra MIME types worth compressing that the generic checks in `is_text`
/// (`text/*`, `+xml` / `+json` suffixes) do not catch.
///
/// `is_text` compares these against `Mime::essence_str()`, so entries are
/// written as bare lowercase `type/subtype` strings without parameters.
const TEXT_MIME_TYPES: [&str; 9] = [
    "application/rtf",
    "application/javascript",
    "application/json",
    "application/xml",
    // Font formats: uncompressed TrueType/OpenType/EOT data typically compresses well.
    // `font/otf` is the registered OpenType type (RFC 8081), alongside `font/ttf`.
    "font/ttf",
    "font/otf",
    "application/font-sfnt",
    "application/vnd.ms-fontobject",
    "application/wasm",
];
52
53const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
55 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
56 ContentCoding::DEFLATE,
57 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
58 ContentCoding::GZIP,
59 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
60 ContentCoding::BROTLI,
61 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
62 ContentCoding::ZSTD,
63];
64
65pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
67 handler_opts.compression = enabled;
68 handler_opts.compression_level = level;
69
70 const FORMATS: &[&str] = &[
71 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
72 "deflate",
73 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
74 "gzip",
75 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
76 "brotli",
77 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
78 "zstd",
79 ];
80 tracing::info!(
81 "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
82 FORMATS.join(",")
83 );
84}
85
86pub(crate) fn post_process<T>(
88 opts: &RequestHandlerOpts,
89 req: &Request<T>,
90 mut resp: Response<Body>,
91) -> Result<Response<Body>, Error> {
92 if !opts.compression {
93 return Ok(resp);
94 }
95
96 let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
97 if is_precompressed {
98 return Ok(resp);
99 }
100
101 let enc = HeaderValue::from_name(hyper::header::ACCEPT_ENCODING);
103 let value = resp.headers().get(hyper::header::VARY).map_or(enc, |h| {
104 let mut a = h.to_str().unwrap_or_default().to_owned();
105 let b = hyper::header::ACCEPT_ENCODING.as_str();
106 if !a.contains(b) {
107 if !a.is_empty() {
108 a.push(',');
109 }
110 a.push_str(b);
111 }
112 HeaderValue::from_str(a.as_str()).unwrap()
113 });
114
115 resp.headers_mut().insert(hyper::header::VARY, value);
116
117 match auto(req.method(), req.headers(), opts.compression_level, resp) {
119 Ok(resp) => Ok(resp),
120 Err(err) => {
121 tracing::error!("error during body compression: {:?}", err);
122 error_page::error_response(
123 req.uri(),
124 req.method(),
125 &StatusCode::INTERNAL_SERVER_ERROR,
126 &opts.page404,
127 &opts.page50x,
128 )
129 }
130 }
131}
132
133pub fn auto(
138 method: &Method,
139 headers: &HeaderMap<HeaderValue>,
140 level: CompressionLevel,
141 resp: Response<Body>,
142) -> Result<Response<Body>> {
143 if method.is_head() || method.is_options() {
145 return Ok(resp);
146 }
147
148 if let Some(encoding) = get_preferred_encoding(headers) {
150 tracing::trace!(
151 "preferred encoding selected from the accept-encoding header: {:?}",
152 encoding
153 );
154
155 if let Some(content_type) = resp.headers().typed_get::<ContentType>()
157 && !is_text(Mime::from(content_type))
158 {
159 return Ok(resp);
160 }
161
162 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
163 if encoding == ContentCoding::GZIP {
164 let (head, body) = resp.into_parts();
165 return Ok(gzip(head, body.into(), level));
166 }
167
168 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
169 if encoding == ContentCoding::DEFLATE {
170 let (head, body) = resp.into_parts();
171 return Ok(deflate(head, body.into(), level));
172 }
173
174 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
175 if encoding == ContentCoding::BROTLI {
176 let (head, body) = resp.into_parts();
177 return Ok(brotli(head, body.into(), level));
178 }
179
180 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
181 if encoding == ContentCoding::ZSTD {
182 let (head, body) = resp.into_parts();
183 return Ok(zstd(head, body.into(), level));
184 }
185
186 tracing::trace!(
187 "no compression feature matched the preferred encoding, probably not enabled or unsupported"
188 );
189 }
190
191 Ok(resp)
192}
193
194fn is_text(mime: Mime) -> bool {
196 mime.type_() == mime::TEXT
197 || mime
198 .suffix()
199 .is_some_and(|suffix| suffix == mime::XML || suffix == mime::JSON)
200 || TEXT_MIME_TYPES.contains(&mime.essence_str())
201}
202
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
/// Wraps `body` in a streaming GZIP encoder and rebuilds the response from `head`.
///
/// `gzip` is appended to any prior `Content-Encoding` via
/// [`create_encoding_header`], and `Content-Length` is removed because the
/// encoded size is not known in advance. `DEFAULT_COMPRESSION_LEVEL` is the
/// fallback handed to `CompressionLevel::into_algorithm_level`.
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);

    // Any upstream length describes the uncompressed bytes and is now stale.
    head.headers.remove(CONTENT_LENGTH);
    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::GZIP),
    );

    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
229
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
/// Wraps `body` in a streaming DEFLATE encoder and rebuilds the response from `head`.
///
/// `deflate` is appended to any prior `Content-Encoding` via
/// [`create_encoding_header`], and `Content-Length` is removed because the
/// encoded size is not known in advance. `DEFAULT_COMPRESSION_LEVEL` is the
/// fallback handed to `CompressionLevel::into_algorithm_level`.
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);

    // Any upstream length describes the uncompressed bytes and is now stale.
    head.headers.remove(CONTENT_LENGTH);
    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::DEFLATE),
    );

    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
259
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
/// Wraps `body` in a streaming Brotli encoder and rebuilds the response from `head`.
///
/// The Brotli coding is appended to any prior `Content-Encoding` via
/// [`create_encoding_header`], and `Content-Length` is removed because the
/// encoded size is not known in advance. `DEFAULT_COMPRESSION_LEVEL` is the
/// fallback handed to `CompressionLevel::into_algorithm_level`.
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);

    // Any upstream length describes the uncompressed bytes and is now stale.
    head.headers.remove(CONTENT_LENGTH);
    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::BROTLI),
    );

    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
287
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
/// Wraps `body` in a streaming Zstandard encoder and rebuilds the response from `head`.
///
/// `zstd` is appended to any prior `Content-Encoding` via
/// [`create_encoding_header`], and `Content-Length` is removed because the
/// encoded size is not known in advance. `DEFAULT_COMPRESSION_LEVEL` is the
/// fallback handed to `CompressionLevel::into_algorithm_level`.
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);

    // Any upstream length describes the uncompressed bytes and is now stale.
    head.headers.remove(CONTENT_LENGTH);
    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::ZSTD),
    );

    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
314
315pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
317 if let Some(val) = existing
318 && let Ok(str_val) = val.to_str()
319 {
320 return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
321 .unwrap_or_else(|_| coding.into());
322 }
323 coding.into()
324}
325
326#[inline(always)]
328pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
329 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
330 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
331
332 for encoding in accept_encoding.sorted_encodings() {
333 if AVAILABLE_ENCODINGS.contains(&encoding) {
334 return Some(encoding);
335 }
336 }
337 }
338 None
339}
340
341#[inline(always)]
343pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
344 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
345 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
346
347 return accept_encoding
348 .sorted_encodings()
349 .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
350 .collect::<Vec<_>>();
351 }
352 vec![]
353}
354
355#[pin_project]
358#[derive(Debug)]
359pub struct CompressableBody<S, E>
360where
361 S: Stream<Item = Result<Bytes, E>>,
362 E: std::error::Error,
363{
364 #[pin]
365 body: S,
366}
367
368impl<S, E> Stream for CompressableBody<S, E>
369where
370 S: Stream<Item = Result<Bytes, E>>,
371 E: std::error::Error,
372{
373 type Item = std::io::Result<Bytes>;
374
375 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
376 use std::io::{Error, ErrorKind};
377
378 let pin = self.project();
379 S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
380 }
381}
382
383impl From<Body> for CompressableBody<Body, hyper::Error> {
384 #[inline(always)]
385 fn from(body: Body) -> Self {
386 CompressableBody { body }
387 }
388}