1#![forbid(unsafe_code)]
19#![deny(
20 missing_copy_implementations,
21 rustdoc::missing_crate_level_docs,
22 missing_debug_implementations,
23 nonstandard_style,
24 unused_qualifications
25)]
26#![warn(missing_docs)]
27
// Pulls README.md in as module docs so its code samples run as doctests.
// rustdoc sets `cfg(doctest)` (not `cfg(test)`) while collecting doctests, so
// gating this module on `cfg(test)` meant the README examples were never run.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
31
32pub use async_compression::Level;
33use async_compression::futures::bufread::{BrotliEncoder, GzipEncoder, ZstdEncoder};
34use futures_lite::{
35 AsyncReadExt,
36 io::{BufReader, Cursor},
37};
38use std::{
39 collections::BTreeSet,
40 fmt::{self, Display, Formatter},
41 str::FromStr,
42};
43use trillium::{
44 Body, Conn, Handler, HeaderValues,
45 KnownHeaderName::{AcceptEncoding, ContentEncoding, ContentType, Vary},
46 conn_try, conn_unwrap,
47};
48
/// A content coding that this crate can apply to response bodies.
///
/// Declaration order is significant: the derived [`Ord`] is used to break ties
/// between codings the client weights equally, and earlier variants win.
#[derive(PartialEq, Eq, Clone, Copy, Debug, Ord, PartialOrd, Hash)]
#[non_exhaustive]
pub enum CompressionAlgorithm {
    /// Brotli, identified by the `br` content-coding token.
    Brotli,

    /// Gzip, identified by the `gzip` token (the legacy `x-gzip` alias is
    /// also accepted when parsing).
    Gzip,

    /// Zstandard, identified by the `zstd` content-coding token.
    Zstd,
}

impl CompressionAlgorithm {
    /// The canonical content-coding token for this algorithm.
    fn as_str(&self) -> &'static str {
        match self {
            CompressionAlgorithm::Brotli => "br",
            CompressionAlgorithm::Gzip => "gzip",
            CompressionAlgorithm::Zstd => "zstd",
        }
    }

    /// Case-sensitive lookup of a content-coding token (including `x-gzip`).
    fn from_str_exact(s: &str) -> Option<Self> {
        match s {
            "br" => Some(CompressionAlgorithm::Brotli),
            "gzip" | "x-gzip" => Some(CompressionAlgorithm::Gzip),
            "zstd" => Some(CompressionAlgorithm::Zstd),
            _ => None,
        }
    }
}

impl AsRef<str> for CompressionAlgorithm {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}

impl Display for CompressionAlgorithm {
    /// Writes the content-coding token. `pad` is used so that width, fill and
    /// alignment flags (e.g. `{:>6}`) are honoured like they are for `str`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}

impl FromStr for CompressionAlgorithm {
    type Err = String;

    /// Parses a content-coding token, ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns `"unrecognized coding {s}"` for any token this crate does not
    /// support.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Try the exact token first to avoid allocating in the common case.
        Self::from_str_exact(s)
            .or_else(|| Self::from_str_exact(&s.to_ascii_lowercase()))
            .ok_or_else(|| format!("unrecognized coding {s}"))
    }
}
104
/// Trillium handler that compresses response bodies with a coding negotiated
/// from the request's `Accept-Encoding` header.
///
/// Construct it with [`Compression::new`] or [`compression`], and tune it with
/// the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct Compression {
    /// Algorithms this handler is allowed to negotiate.
    algorithms: BTreeSet<CompressionAlgorithm>,
    /// Quality level passed to the brotli encoder.
    brotli_level: Level,
    /// Quality level passed to the gzip encoder.
    gzip_level: Level,
    /// Quality level passed to the zstd encoder.
    zstd_level: Level,
}
113
114impl Default for Compression {
115 fn default() -> Self {
116 use CompressionAlgorithm::*;
117 Self {
118 algorithms: [Zstd, Brotli, Gzip].into_iter().collect(),
119 brotli_level: Level::Precise(4),
123 gzip_level: Level::Default,
124 zstd_level: Level::Default,
125 }
126 }
127}
128
129impl Compression {
130 pub fn new() -> Self {
132 Self::default()
133 }
134
135 fn set_algorithms(&mut self, algos: &[CompressionAlgorithm]) {
136 self.algorithms = algos.iter().copied().collect();
137 }
138
139 pub fn with_algorithms(mut self, algorithms: &[CompressionAlgorithm]) -> Self {
143 self.set_algorithms(algorithms);
144 self
145 }
146
147 pub fn with_brotli_level(mut self, level: Level) -> Self {
152 self.brotli_level = level;
153 self
154 }
155
156 pub fn with_gzip_level(mut self, level: Level) -> Self {
159 self.gzip_level = level;
160 self
161 }
162
163 pub fn with_zstd_level(mut self, level: Level) -> Self {
166 self.zstd_level = level;
167 self
168 }
169
170 fn negotiate(&self, header: &str) -> Option<CompressionAlgorithm> {
171 parse_accept_encoding(header)
172 .into_iter()
173 .find_map(|(algo, _)| {
174 if self.algorithms.contains(&algo) {
175 Some(algo)
176 } else {
177 None
178 }
179 })
180 }
181}
182
183fn parse_accept_encoding(header: &str) -> Vec<(CompressionAlgorithm, u8)> {
184 let mut vec = header
185 .split(',')
186 .filter_map(|s| {
187 let mut iter = s.trim().split(';');
188 let (algo, q) = (iter.next()?, iter.next());
189 let algo = algo.trim().parse().ok()?;
190 let q = q
191 .and_then(|q| {
192 q.trim()
193 .strip_prefix("q=")
194 .and_then(|q| q.parse::<f32>().map(|f| (f * 100.0) as u8).ok())
195 })
196 .unwrap_or(100u8);
197 Some((algo, q))
198 })
199 .collect::<Vec<(CompressionAlgorithm, u8)>>();
200
201 vec.sort_by(|(algo_a, a), (algo_b, b)| match b.cmp(a) {
202 std::cmp::Ordering::Equal => algo_a.cmp(algo_b),
203 other => other,
204 });
205
206 vec
207}
208
/// Returns `true` when `content_type` names a format whose payload is already
/// compressed, so running it through another encoder would only cost CPU.
///
/// Parameters such as `; charset=...` are ignored. The comparison is
/// ASCII-case-insensitive, because media type names are case-insensitive.
/// Every `video/*` and `audio/*` type is treated as already compressed.
fn is_already_compressed(content_type: &str) -> bool {
    let primary = content_type
        .split(';')
        .next()
        .unwrap_or(content_type)
        .trim()
        .to_ascii_lowercase();

    if primary.starts_with("video/") || primary.starts_with("audio/") {
        return true;
    }

    matches!(
        primary.as_str(),
        "image/png"
            | "image/jpeg"
            | "image/jpg"
            | "image/gif"
            | "image/webp"
            | "image/avif"
            | "image/heic"
            | "image/heif"
            | "image/apng"
            | "image/x-icon"
            | "font/woff"
            | "font/woff2"
            | "application/zip"
            | "application/gzip"
            | "application/x-gzip"
            | "application/x-bzip2"
            | "application/x-xz"
            | "application/x-7z-compressed"
            | "application/x-rar-compressed"
            | "application/zstd"
    )
}
256
257impl Handler for Compression {
258 async fn run(&self, mut conn: Conn) -> Conn {
259 if let Some(header) = conn
260 .request_headers()
261 .get_str(AcceptEncoding)
262 .and_then(|h| self.negotiate(h))
263 {
264 conn.insert_state(header);
265 }
266 conn
267 }
268
269 async fn before_send(&self, mut conn: Conn) -> Conn {
270 if conn.response_headers().get_str(ContentEncoding).is_some() {
273 return conn;
274 }
275
276 if conn
278 .response_headers()
279 .get_str(ContentType)
280 .is_some_and(is_already_compressed)
281 {
282 return conn;
283 }
284
285 let Some(algo) = conn.state::<CompressionAlgorithm>().copied() else {
286 return conn;
287 };
288
289 let mut body = conn_unwrap!(conn.take_response_body(), conn);
290 let mut compression_used = false;
291
292 if body.is_static() {
293 let bytes = body.static_bytes().unwrap();
294 let mut data = vec![];
295 match algo {
296 CompressionAlgorithm::Zstd => {
297 let mut encoder =
298 ZstdEncoder::with_quality(Cursor::new(bytes), self.zstd_level);
299 conn_try!(encoder.read_to_end(&mut data).await, conn);
300 }
301 CompressionAlgorithm::Brotli => {
302 let mut encoder =
303 BrotliEncoder::with_quality(Cursor::new(bytes), self.brotli_level);
304 conn_try!(encoder.read_to_end(&mut data).await, conn);
305 }
306 CompressionAlgorithm::Gzip => {
307 let mut encoder =
308 GzipEncoder::with_quality(Cursor::new(bytes), self.gzip_level);
309 conn_try!(encoder.read_to_end(&mut data).await, conn);
310 }
311 }
312 if data.len() < bytes.len() {
313 log::trace!(
314 "{} body from {} to {}",
315 algo.as_str(),
316 bytes.len(),
317 data.len()
318 );
319 compression_used = true;
320 body = Body::new_static(data);
321 }
322 } else if body.is_streaming() {
323 compression_used = true;
324 match algo {
325 CompressionAlgorithm::Zstd => {
326 body = Body::new_streaming(
327 ZstdEncoder::with_quality(
328 BufReader::new(body.into_reader()),
329 self.zstd_level,
330 ),
331 None,
332 );
333 }
334 CompressionAlgorithm::Brotli => {
335 body = Body::new_streaming(
336 BrotliEncoder::with_quality(
337 BufReader::new(body.into_reader()),
338 self.brotli_level,
339 ),
340 None,
341 );
342 }
343 CompressionAlgorithm::Gzip => {
344 body = Body::new_streaming(
345 GzipEncoder::with_quality(
346 BufReader::new(body.into_reader()),
347 self.gzip_level,
348 ),
349 None,
350 );
351 }
352 }
353 }
354
355 if compression_used {
356 let vary = conn
357 .response_headers()
358 .get_str(Vary)
359 .map(|vary| HeaderValues::from(format!("{vary}, Accept-Encoding")))
360 .unwrap_or_else(|| HeaderValues::from("Accept-Encoding"));
361
362 conn.response_headers_mut().extend([
363 (ContentEncoding, HeaderValues::from(algo.as_str())),
364 (Vary, vary),
365 ]);
366 }
367
368 conn.with_body(body)
369 }
370}
371
372pub fn compression() -> Compression {
374 Compression::new()
375}