tide_compress/
middleware.rs1use tide::http::cache::{CacheControl, CacheDirective};
2use tide::http::conditional::Vary;
3use tide::http::content::{AcceptEncoding, ContentEncoding, Encoding};
4use tide::http::{headers, Body, Method};
5use tide::{Middleware, Next, Request, Response};
6
7#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))]
8use async_compression::Level;
9#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))]
10use futures_lite::io::BufReader;
11
12#[cfg(feature = "brotli")]
13use async_compression::futures::bufread::BrotliEncoder;
14#[cfg(feature = "deflate")]
15use async_compression::futures::bufread::DeflateEncoder;
16#[cfg(feature = "gzip")]
17use async_compression::futures::bufread::GzipEncoder;
18
19#[cfg(feature = "regex-check")]
20use http_types::content::ContentType;
21#[cfg(feature = "regex-check")]
22use regex::{Regex, RegexBuilder};
23
/// Default minimum body length, in bytes; responses whose known body length is below this
/// are passed through uncompressed.
const THRESHOLD: usize = 1024;

/// Default content-type check: matches media types starting with `text/` or ending in
/// `+json`, `+text`, or `+xml`. The builder's `Default` compiles it case-insensitively.
#[cfg(feature = "regex-check")]
const CONTENT_TYPE_CHECK_PATTERN: &str = r"^text/|\+(?:json|text|xml)$";
/// Extracts the bare media type (capture group 1) from a `Content-Type` value, stopping at
/// the first `;`, whitespace, or end of input so parameters such as `charset` are ignored.
#[cfg(feature = "regex-check")]
const EXTRACT_TYPE_PATTERN: &str = r"^\s*([^;\s]*)(?:;|\s|$)";
33
/// Tide middleware that compresses response bodies according to the request's
/// `Accept-Encoding` header.
///
/// Construct it with [`CompressMiddleware::new`], [`Default`], or
/// [`CompressMiddleware::builder`].
#[derive(Clone, Debug)]
pub struct CompressMiddleware {
    /// Responses whose body has a known length below this many bytes are not compressed.
    threshold: usize,
    /// Optional regex the response's media type must match to be compressed; `None`
    /// disables the check. With the `db-check` feature, types found in the MIME database
    /// also qualify.
    #[cfg(feature = "regex-check")]
    content_type_check: Option<Regex>,
    /// Regex used to pull the bare media type out of a `Content-Type` header value.
    #[cfg(feature = "regex-check")]
    extract_type_regex: Regex,
    /// Quality level passed to the Brotli encoder.
    #[cfg(feature = "brotli")]
    brotli_quality: Level,
    /// Quality level shared by the gzip and deflate encoders.
    #[cfg(any(feature = "gzip", feature = "deflate"))]
    deflate_quality: Level,
}
56
57impl Default for CompressMiddleware {
58 fn default() -> Self {
59 CompressMiddlewareBuilder::default().into()
60 }
61}
62
63impl CompressMiddleware {
64 pub fn new() -> Self {
79 Self::default()
80 }
81
82 pub fn builder() -> CompressMiddlewareBuilder {
86 CompressMiddlewareBuilder::new()
87 }
88
89 pub fn set_threshold(&mut self, threshold: usize) {
91 self.threshold = threshold
92 }
93
94 pub fn threshold(&self) -> usize {
96 self.threshold
97 }
98
99 #[cfg(feature = "regex-check")]
100 pub fn set_content_type_check(&mut self, content_type_check: Option<Regex>) {
102 self.content_type_check = content_type_check
103 }
104
105 #[cfg(feature = "regex-check")]
106 pub fn content_type_check(&self) -> Option<&Regex> {
108 self.content_type_check.as_ref()
109 }
110}
111
112#[tide::utils::async_trait]
113impl<State: Clone + Send + Sync + 'static> Middleware<State> for CompressMiddleware {
114 async fn handle(&self, req: Request<State>, next: Next<'_, State>) -> tide::Result {
115 let is_head = req.method() == Method::Head;
118 let accepts = AcceptEncoding::from_headers(&req)?;
119
120 let mut res: Response = next.run(req).await;
122
123 if is_head || accepts.is_none() {
126 return Ok(res);
127 }
128 let mut accepts = accepts.expect("checked directly above");
129
130 if let Some(cache_control) = CacheControl::from_headers(&res)? {
132 if cache_control
135 .iter()
136 .any(|directive| directive == &CacheDirective::NoTransform)
137 {
138 return Ok(res);
139 }
140 }
141
142 let mut vary = Vary::new();
144 vary.push(headers::ACCEPT_ENCODING)?;
145 vary.apply(&mut res);
146
147 if let Some(previous_encoding) = ContentEncoding::from_headers(&res)? {
150 if previous_encoding != Encoding::Identity {
151 return Ok(res);
152 }
153 }
154
155 if let Some(body_len) = res.len() {
157 if body_len < self.threshold {
158 return Ok(res);
159 }
160 }
161
162 #[cfg(feature = "regex-check")]
163 if let Some(ref content_type_check) = self.content_type_check {
165 if let Some(content_type) = ContentType::from_headers(&res)? {
166 if let Some(extension_match) = self
167 .extract_type_regex
168 .captures(content_type.value().as_str())
169 .and_then(|captures| captures.get(1))
170 {
171 #[cfg(feature = "db-check")]
172 if !crate::codegen_database::MIME_DB.contains(extension_match.as_str())
176 && !content_type_check.is_match(extension_match.as_str())
177 {
178 return Ok(res);
179 }
180 #[cfg(not(feature = "db-check"))]
181 if !content_type_check.is_match(extension_match.as_str()) {
182 return Ok(res);
183 }
184 }
185 }
186 }
187
188 let encoding = accepts.negotiate(&[
189 #[cfg(feature = "brotli")]
190 Encoding::Brotli,
191 #[cfg(feature = "gzip")]
192 Encoding::Gzip,
193 #[cfg(feature = "deflate")]
194 Encoding::Deflate,
195 Encoding::Identity, ])?;
197
198 if encoding == Encoding::Identity {
200 res.remove_header(headers::CONTENT_ENCODING);
201 return Ok(res);
202 }
203
204 let body = res.take_body();
205 res.set_body(get_encoder(
207 body,
208 &encoding,
209 #[cfg(feature = "brotli")]
210 self.brotli_quality,
211 #[cfg(any(feature = "gzip", feature = "deflate"))]
212 self.deflate_quality,
213 ));
214 encoding.apply(&mut res);
215
216 res.remove_header(headers::CONTENT_LENGTH);
218
219 Ok(res)
220 }
221}
222
223#[cfg_attr(
225 not(any(feature = "brotli", feature = "deflate", feature = "gzip")),
226 allow(unused_variables)
227)]
228fn get_encoder(
229 body: Body,
230 encoding: &ContentEncoding,
231 #[cfg(feature = "brotli")] brotli_quality: Level,
232 #[cfg(any(feature = "gzip", feature = "deflate"))] deflate_quality: Level,
233) -> Body {
234 #[cfg(feature = "brotli")]
235 {
236 if *encoding == Encoding::Brotli {
237 return Body::from_reader(
238 BufReader::new(BrotliEncoder::with_quality(body, brotli_quality)),
239 None,
240 );
241 }
242 }
243
244 #[cfg(feature = "gzip")]
245 {
246 if *encoding == Encoding::Gzip {
247 return Body::from_reader(
248 BufReader::new(GzipEncoder::with_quality(body, deflate_quality)),
249 None,
250 );
251 }
252 }
253
254 #[cfg(feature = "deflate")]
255 {
256 if *encoding == Encoding::Deflate {
257 return Body::from_reader(
258 BufReader::new(DeflateEncoder::with_quality(body, deflate_quality)),
259 None,
260 );
261 }
262 }
263
264 body
265}
266
/// Builder for [`CompressMiddleware`]. Its fields are public, so they may also be set
/// directly before calling [`CompressMiddlewareBuilder::build`].
#[derive(Clone, Debug)]
pub struct CompressMiddlewareBuilder {
    /// Minimum known body length, in bytes, required before compressing (default: 1024).
    pub threshold: usize,
    /// Regex the response's media type must match to be compressed; `None` disables the
    /// check. Defaults to a case-insensitive match for `text/*` and `+json`/`+text`/`+xml`.
    #[cfg(feature = "regex-check")]
    pub content_type_check: Option<Regex>,
    /// Brotli encoder quality (default: `Level::Fastest`).
    #[cfg(feature = "brotli")]
    pub brotli_quality: Level,
    /// Quality shared by the gzip and deflate encoders (default: `Level::Default`).
    #[cfg(any(feature = "gzip", feature = "deflate"))]
    pub deflate_quality: Level,
}
304
305impl Default for CompressMiddlewareBuilder {
306 fn default() -> Self {
307 Self {
308 threshold: THRESHOLD,
309 #[cfg(feature = "regex-check")]
310 content_type_check: Some(
311 RegexBuilder::new(CONTENT_TYPE_CHECK_PATTERN)
312 .case_insensitive(true)
313 .build()
314 .expect("Constant regular expression defined in Tide-Compress's source code"),
315 ),
316 #[cfg(feature = "brotli")]
317 brotli_quality: Level::Fastest,
318 #[cfg(any(feature = "gzip", feature = "deflate"))]
319 deflate_quality: Level::Default,
320 }
321 }
322}
323
324impl CompressMiddlewareBuilder {
325 pub fn new() -> Self {
328 Self::default()
329 }
330
331 pub fn threshold(mut self, threshold: usize) -> Self {
333 self.threshold = threshold;
334 self
335 }
336
337 #[cfg(feature = "regex-check")]
338 pub fn content_type_check(mut self, content_type_check: Option<Regex>) -> Self {
340 self.content_type_check = content_type_check;
341 self
342 }
343
344 #[cfg(feature = "brotli")]
345 pub fn brotli_quality(mut self, quality: Level) -> Self {
347 self.brotli_quality = quality;
348 self
349 }
350
351 #[cfg(any(feature = "gzip", feature = "deflate"))]
352 pub fn deflate_quality(mut self, quality: Level) -> Self {
354 self.deflate_quality = quality;
355 self
356 }
357
358 pub fn build(self) -> CompressMiddleware {
360 self.into()
361 }
362}
363
364impl From<CompressMiddlewareBuilder> for CompressMiddleware {
365 fn from(builder: CompressMiddlewareBuilder) -> Self {
366 Self {
367 threshold: builder.threshold,
368 #[cfg(feature = "regex-check")]
369 content_type_check: builder.content_type_check,
370 #[cfg(feature = "regex-check")]
371 extract_type_regex: Regex::new(EXTRACT_TYPE_PATTERN)
372 .expect("Constant regular expression defined in Tide-Compress's source code"),
373 #[cfg(feature = "brotli")]
374 brotli_quality: builder.brotli_quality,
375 #[cfg(any(feature = "gzip", feature = "deflate"))]
376 deflate_quality: builder.deflate_quality,
377 }
378 }
379}