salvo_compression/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3//! Compression middleware for the Salvo web framework.
4//!
5//! Read more: <https://salvo.rs>
6
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use indexmap::IndexMap;

use salvo_core::http::body::ResBody;
use salvo_core::http::header::{
    ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
    CONTENT_TYPE, HeaderValue, VARY,
};
use salvo_core::http::{self, Mime, StatusCode, mime};
use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
18
19mod encoder;
20mod stream;
21use encoder::Encoder;
22use stream::EncodeStream;
23
/// How strongly a response body should be compressed.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionLevel {
    /// Fastest setting: trades compression ratio for speed, so the output is usually larger.
    Fastest,
    /// Strongest setting: usually yields the smallest output, at the cost of speed.
    Minsize,
    /// Whatever quality the selected compression algorithm defines as its default.
    #[default]
    Default,
    /// An explicit quality value based on the underlying algorithm's own quality scale.
    ///
    /// How the number is interpreted depends on the chosen algorithm and the implementation
    /// backing it. Values above the algorithm's maximum are implicitly clamped to it.
    Precise(u32),
}
41
/// A content coding this middleware can produce for response bodies.
///
/// Each variant only exists when the cargo feature of the same name is enabled.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionAlgo {
    /// Brotli, sent on the wire as the `br` content coding.
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    Brotli,

    /// Deflate, sent on the wire as the `deflate` content coding.
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    Deflate,

    /// Gzip, sent on the wire as the `gzip` content coding.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,

    /// Zstandard, sent on the wire as the `zstd` content coding.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
66
67impl FromStr for CompressionAlgo {
68    type Err = String;
69
70    fn from_str(s: &str) -> Result<Self, Self::Err> {
71        match s {
72            #[cfg(feature = "brotli")]
73            "br" => Ok(CompressionAlgo::Brotli),
74            #[cfg(feature = "brotli")]
75            "brotli" => Ok(CompressionAlgo::Brotli),
76
77            #[cfg(feature = "deflate")]
78            "deflate" => Ok(CompressionAlgo::Deflate),
79
80            #[cfg(feature = "gzip")]
81            "gzip" => Ok(CompressionAlgo::Gzip),
82
83            #[cfg(feature = "zstd")]
84            "zstd" => Ok(CompressionAlgo::Zstd),
85            _ => Err(format!("unknown compression algorithm: {s}")),
86        }
87    }
88}
89
90impl Display for CompressionAlgo {
91    #[allow(unreachable_patterns)]
92    #[allow(unused_variables)]
93    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
94        match self {
95            #[cfg(feature = "brotli")]
96            CompressionAlgo::Brotli => write!(f, "br"),
97            #[cfg(feature = "deflate")]
98            CompressionAlgo::Deflate => write!(f, "deflate"),
99            #[cfg(feature = "gzip")]
100            CompressionAlgo::Gzip => write!(f, "gzip"),
101            #[cfg(feature = "zstd")]
102            CompressionAlgo::Zstd => write!(f, "zstd"),
103            _ => unreachable!(),
104        }
105    }
106}
107
108impl From<CompressionAlgo> for HeaderValue {
109    #[inline]
110    fn from(algo: CompressionAlgo) -> Self {
111        match algo {
112            #[cfg(feature = "brotli")]
113            CompressionAlgo::Brotli => HeaderValue::from_static("br"),
114            #[cfg(feature = "deflate")]
115            CompressionAlgo::Deflate => HeaderValue::from_static("deflate"),
116            #[cfg(feature = "gzip")]
117            CompressionAlgo::Gzip => HeaderValue::from_static("gzip"),
118            #[cfg(feature = "zstd")]
119            CompressionAlgo::Zstd => HeaderValue::from_static("zstd"),
120        }
121    }
122}
123
/// Middleware that compresses response bodies according to the request's `Accept-Encoding`.
///
/// Build one with `Compression::new()` and the builder methods, then install it as a hoop,
/// e.g. `Router::with_hoop(Compression::new())`.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Compression {
    /// Enabled algorithms and the level to use for each.
    ///
    /// Iteration order is the server's preference order; it decides the winner when
    /// `force_priority` is `true`.
    pub algos: IndexMap<CompressionAlgo, CompressionLevel>,
    /// Content types eligible for compression.
    ///
    /// A `*` subtype matches any subtype of that type. An empty list makes every response
    /// eligible, regardless of its `Content-Type`.
    pub content_types: Vec<Mime>,
    /// Minimum body size in bytes.
    ///
    /// In-memory bodies shorter than this are sent uncompressed, and `0` disables the check.
    /// Streaming bodies are not checked.
    pub min_length: usize,
    /// When `true`, the server's order in `algos` decides which algorithm is used, and the
    /// client's preference order in `Accept-Encoding` is ignored.
    pub force_priority: bool,
}
137
138impl Default for Compression {
139    fn default() -> Self {
140        #[allow(unused_mut)]
141        let mut algos = IndexMap::new();
142        #[cfg(feature = "zstd")]
143        algos.insert(CompressionAlgo::Zstd, CompressionLevel::Default);
144        #[cfg(feature = "gzip")]
145        algos.insert(CompressionAlgo::Gzip, CompressionLevel::Default);
146        #[cfg(feature = "deflate")]
147        algos.insert(CompressionAlgo::Deflate, CompressionLevel::Default);
148        #[cfg(feature = "brotli")]
149        algos.insert(CompressionAlgo::Brotli, CompressionLevel::Default);
150        Self {
151            algos,
152            content_types: vec![
153                mime::TEXT_STAR,
154                mime::APPLICATION_JAVASCRIPT,
155                mime::APPLICATION_JSON,
156                mime::IMAGE_SVG,
157                "application/wasm".parse().expect("invalid mime type"),
158                "application/xml".parse().expect("invalid mime type"),
159                "application/rss+xml".parse().expect("invalid mime type"),
160            ],
161            min_length: 0,
162            force_priority: false,
163        }
164    }
165}
166
167impl Compression {
168    /// Create a new `Compression`.
169    #[inline]
170    pub fn new() -> Self {
171        Default::default()
172    }
173
174    /// Remove all compression algorithms.
175    #[inline]
176    pub fn disable_all(mut self) -> Self {
177        self.algos.clear();
178        self
179    }
180
181    /// Sets `Compression` with algos.
182    #[cfg(feature = "gzip")]
183    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
184    #[inline]
185    pub fn enable_gzip(mut self, level: CompressionLevel) -> Self {
186        self.algos.insert(CompressionAlgo::Gzip, level);
187        self
188    }
189    /// Disable gzip compression.
190    #[cfg(feature = "gzip")]
191    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
192    #[inline]
193    pub fn disable_gzip(mut self) -> Self {
194        self.algos.shift_remove(&CompressionAlgo::Gzip);
195        self
196    }
197    /// Enable zstd compression.
198    #[cfg(feature = "zstd")]
199    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
200    #[inline]
201    pub fn enable_zstd(mut self, level: CompressionLevel) -> Self {
202        self.algos.insert(CompressionAlgo::Zstd, level);
203        self
204    }
205    /// Disable zstd compression.
206    #[cfg(feature = "zstd")]
207    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
208    #[inline]
209    pub fn disable_zstd(mut self) -> Self {
210        self.algos.shift_remove(&CompressionAlgo::Zstd);
211        self
212    }
213    /// Enable brotli compression.
214    #[cfg(feature = "brotli")]
215    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
216    #[inline]
217    pub fn enable_brotli(mut self, level: CompressionLevel) -> Self {
218        self.algos.insert(CompressionAlgo::Brotli, level);
219        self
220    }
221    /// Disable brotli compression.
222    #[cfg(feature = "brotli")]
223    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
224    #[inline]
225    pub fn disable_brotli(mut self) -> Self {
226        self.algos.shift_remove(&CompressionAlgo::Brotli);
227        self
228    }
229
230    /// Enable deflate compression.
231    #[cfg(feature = "deflate")]
232    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
233    #[inline]
234    pub fn enable_deflate(mut self, level: CompressionLevel) -> Self {
235        self.algos.insert(CompressionAlgo::Deflate, level);
236        self
237    }
238
239    /// Disable deflate compression.
240    #[cfg(feature = "deflate")]
241    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
242    #[inline]
243    pub fn disable_deflate(mut self) -> Self {
244        self.algos.shift_remove(&CompressionAlgo::Deflate);
245        self
246    }
247
248    /// Sets minimum compression size, if body is less than this value, no compression
249    /// default is 1kb
250    #[inline]
251    pub fn min_length(mut self, size: usize) -> Self {
252        self.min_length = size;
253        self
254    }
255    /// Sets `Compression` with force_priority.
256    #[inline]
257    pub fn force_priority(mut self, force_priority: bool) -> Self {
258        self.force_priority = force_priority;
259        self
260    }
261
262    /// Sets `Compression` with content types list.
263    #[inline]
264    pub fn content_types(mut self, content_types: &[Mime]) -> Self {
265        self.content_types = content_types.to_vec();
266        self
267    }
268
269    fn negotiate(
270        &self,
271        req: &Request,
272        res: &Response,
273    ) -> Option<(CompressionAlgo, CompressionLevel)> {
274        if req.headers().contains_key(&CONTENT_ENCODING) {
275            return None;
276        }
277
278        if !self.content_types.is_empty() {
279            let content_type = res
280                .headers()
281                .get(CONTENT_TYPE)
282                .and_then(|v| v.to_str().ok())
283                .unwrap_or_default();
284            if content_type.is_empty() {
285                return None;
286            }
287            if let Ok(content_type) = content_type.parse::<Mime>() {
288                if !self.content_types.iter().any(|citem| {
289                    citem.type_() == content_type.type_()
290                        && (citem.subtype() == "*" || citem.subtype() == content_type.subtype())
291                }) {
292                    return None;
293                }
294            } else {
295                return None;
296            }
297        }
298        let header = req
299            .headers()
300            .get(ACCEPT_ENCODING)
301            .and_then(|v| v.to_str().ok())?;
302
303        let accept_algos = http::parse_accept_encoding(header)
304            .into_iter()
305            .filter_map(|(algo, level)| {
306                if let Ok(algo) = algo.parse::<CompressionAlgo>() {
307                    Some((algo, level))
308                } else {
309                    None
310                }
311            })
312            .collect::<Vec<_>>();
313        if self.force_priority {
314            let accept_algos = accept_algos
315                .into_iter()
316                .map(|(algo, _)| algo)
317                .collect::<Vec<_>>();
318            self.algos
319                .iter()
320                .find(|(algo, _level)| accept_algos.contains(algo))
321                .map(|(algo, level)| (*algo, *level))
322        } else {
323            accept_algos
324                .into_iter()
325                .find_map(|(algo, _)| self.algos.get(&algo).map(|level| (algo, *level)))
326        }
327    }
328}
329
330#[async_trait]
331impl Handler for Compression {
332    async fn handle(
333        &self,
334        req: &mut Request,
335        depot: &mut Depot,
336        res: &mut Response,
337        ctrl: &mut FlowCtrl,
338    ) {
339        ctrl.call_next(req, depot, res).await;
340        if ctrl.is_ceased() || res.headers().contains_key(CONTENT_ENCODING) {
341            return;
342        }
343
344        if let Some(code) = res.status_code {
345            if code == StatusCode::SWITCHING_PROTOCOLS || code == StatusCode::NO_CONTENT {
346                return;
347            }
348        }
349
350        match res.take_body() {
351            ResBody::None => {
352                return;
353            }
354            ResBody::Once(bytes) => {
355                if self.min_length > 0 && bytes.len() < self.min_length {
356                    res.body(ResBody::Once(bytes));
357                    return;
358                }
359                match self.negotiate(req, res) {
360                    Some((algo, level)) => {
361                        res.stream(EncodeStream::new(algo, level, Some(bytes)));
362                        res.headers_mut().append(CONTENT_ENCODING, algo.into());
363                    }
364                    None => {
365                        res.body(ResBody::Once(bytes));
366                        return;
367                    }
368                }
369            }
370            ResBody::Chunks(chunks) => {
371                if self.min_length > 0 {
372                    let len: usize = chunks.iter().map(|c| c.len()).sum();
373                    if len < self.min_length {
374                        res.body(ResBody::Chunks(chunks));
375                        return;
376                    }
377                }
378                match self.negotiate(req, res) {
379                    Some((algo, level)) => {
380                        res.stream(EncodeStream::new(algo, level, chunks));
381                        res.headers_mut().append(CONTENT_ENCODING, algo.into());
382                    }
383                    None => {
384                        res.body(ResBody::Chunks(chunks));
385                        return;
386                    }
387                }
388            }
389            ResBody::Hyper(body) => match self.negotiate(req, res) {
390                Some((algo, level)) => {
391                    res.stream(EncodeStream::new(algo, level, body));
392                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
393                }
394                None => {
395                    res.body(ResBody::Hyper(body));
396                    return;
397                }
398            },
399            ResBody::Stream(body) => {
400                let body = body.into_inner();
401                match self.negotiate(req, res) {
402                    Some((algo, level)) => {
403                        res.stream(EncodeStream::new(algo, level, body));
404                        res.headers_mut().append(CONTENT_ENCODING, algo.into());
405                    }
406                    None => {
407                        res.body(ResBody::stream(body));
408                        return;
409                    }
410                }
411            }
412            body => {
413                res.body(body);
414                return;
415            }
416        }
417        res.headers_mut().remove(CONTENT_LENGTH);
418    }
419}
420
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[handler]
    async fn hello() -> &'static str {
        "hello"
    }

    /// Sends `GET /hello` through a router guarded by `comp_handler`.
    ///
    /// An `Accept-Encoding` header is added when `accept` is `Some`.
    async fn get_hello(comp_handler: Compression, accept: Option<&str>) -> Response {
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
        let mut request = TestClient::get("http://127.0.0.1:5801/hello");
        if let Some(accept) = accept {
            request = request.add_header(ACCEPT_ENCODING, accept, true);
        }
        request.send(router).await
    }

    /// Asserts that `Accept-Encoding: accept` yields `Content-Encoding: expected`, and that the
    /// body decodes back to `"hello"`.
    #[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))]
    async fn assert_round_trip(accept: &str, expected: &str) {
        let mut res = get_hello(Compression::new().min_length(1), Some(accept)).await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), expected);
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_gzip() {
        assert_round_trip("gzip", "gzip").await;
    }

    #[cfg(feature = "brotli")]
    #[tokio::test]
    async fn test_brotli() {
        assert_round_trip("br", "br").await;
    }

    #[cfg(feature = "deflate")]
    #[tokio::test]
    async fn test_deflate() {
        assert_round_trip("deflate", "deflate").await;
    }

    #[cfg(feature = "zstd")]
    #[tokio::test]
    async fn test_zstd_is_negotiated() {
        let res = get_hello(Compression::new().min_length(1), Some("zstd")).await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "zstd");
    }

    #[tokio::test]
    async fn test_no_accept_encoding_sends_plain_body() {
        let mut res = get_hello(Compression::new().min_length(1), None).await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
        assert_eq!(res.take_string().await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn test_body_below_min_length_sends_plain_body() {
        let mut res = get_hello(
            Compression::new().min_length(1024),
            Some("gzip, deflate, br, zstd"),
        )
        .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
        assert_eq!(res.take_string().await.unwrap(), "hello");
    }
}