poem/middleware/compression.rs

1use std::{collections::HashSet, str::FromStr};
2
3use headers::HeaderMap;
4
5use crate::{
6    Body, Endpoint, IntoResponse, Middleware, Request, Response, Result,
7    http::header,
8    web::{Compress, CompressionAlgo, CompressionLevel},
9};
10
/// A content-coding token that the `Accept-Encoding` negotiation understands.
///
/// Variants are listed from most to least preferred; `Star` stands for the
/// `*` wildcard token.
enum ContentCoding {
    /// Zstandard.
    Zstd,
    /// Brotli.
    Brotli,
    /// Gzip.
    Gzip,
    /// Deflate.
    Deflate,
    /// The `*` wildcard.
    Star,
}
18
19impl FromStr for ContentCoding {
20    type Err = ();
21
22    fn from_str(s: &str) -> Result<Self, Self::Err> {
23        if s.eq_ignore_ascii_case("deflate") {
24            Ok(ContentCoding::Deflate)
25        } else if s.eq_ignore_ascii_case("gzip") {
26            Ok(ContentCoding::Gzip)
27        } else if s.eq_ignore_ascii_case("br") {
28            Ok(ContentCoding::Brotli)
29        } else if s == "*" {
30            Ok(ContentCoding::Star)
31        } else if s == "zsdt" {
32            Ok(ContentCoding::Zstd)
33        } else {
34            Err(())
35        }
36    }
37}
38
39fn parse_accept_encoding(
40    headers: &HeaderMap,
41    enabled_algorithms: &HashSet<CompressionAlgo>,
42) -> Option<ContentCoding> {
43    headers
44        .get_all(header::ACCEPT_ENCODING)
45        .iter()
46        .filter_map(|hval| hval.to_str().ok())
47        .flat_map(|s| s.split(',').map(str::trim))
48        .filter_map(|v| {
49            let (e, q) = match v.split_once(";q=") {
50                Some((e, q)) => (e, (q.parse::<f32>().ok()? * 1000.0) as i32),
51                None => (v, 1000),
52            };
53            let coding: ContentCoding = e.parse().ok()?;
54            Some((coding, q))
55        })
56        .filter(|(encoding, _)| {
57            if !enabled_algorithms.is_empty() {
58                match encoding {
59                    ContentCoding::Brotli => enabled_algorithms.contains(&CompressionAlgo::BR),
60                    ContentCoding::Deflate => {
61                        enabled_algorithms.contains(&CompressionAlgo::DEFLATE)
62                    }
63                    ContentCoding::Gzip => enabled_algorithms.contains(&CompressionAlgo::GZIP),
64                    ContentCoding::Zstd => enabled_algorithms.contains(&CompressionAlgo::ZSTD),
65                    _ => true,
66                }
67            } else {
68                true
69            }
70        })
71        .max_by_key(|(coding, q)| (*q, coding_priority(coding)))
72        .map(|(coding, _)| coding)
73}
74
75/// Middleware to decompress the request body and compress the response body.
76///
77/// The decompression algorithm is selected according to the request
78/// `Content-Encoding` header, and the compression algorithm is selected
79/// according to the request `Accept-Encoding` header.
80#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
81#[derive(Default)]
82pub struct Compression {
83    level: Option<CompressionLevel>,
84    algorithms: HashSet<CompressionAlgo>,
85}
86
87impl Compression {
88    /// Creates a new `Compression` middleware.
89    #[must_use]
90    pub fn new() -> Self {
91        Self::default()
92    }
93
94    /// Specify the compression level
95    #[must_use]
96    #[inline]
97    pub fn with_quality(self, level: CompressionLevel) -> Self {
98        Self {
99            level: Some(level),
100            ..self
101        }
102    }
103
104    /// Specify the enabled algorithms (defaults to all)
105    #[must_use]
106    #[inline]
107    pub fn algorithms(self, algorithms: impl IntoIterator<Item = CompressionAlgo>) -> Self {
108        Self {
109            algorithms: algorithms.into_iter().collect(),
110            ..self
111        }
112    }
113}
114
115impl<E: Endpoint> Middleware<E> for Compression {
116    type Output = CompressionEndpoint<E>;
117
118    fn transform(&self, ep: E) -> Self::Output {
119        CompressionEndpoint {
120            ep,
121            level: self.level,
122            algorithms: self.algorithms.clone(),
123        }
124    }
125}
126
127/// Endpoint for the Compression middleware.
128#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
129pub struct CompressionEndpoint<E: Endpoint> {
130    ep: E,
131    level: Option<CompressionLevel>,
132    algorithms: HashSet<CompressionAlgo>,
133}
134
135#[inline]
136fn coding_priority(c: &ContentCoding) -> u8 {
137    match *c {
138        ContentCoding::Deflate => 1,
139        ContentCoding::Gzip => 2,
140        ContentCoding::Brotli => 3,
141        ContentCoding::Zstd => 4,
142        _ => 0,
143    }
144}
145
146impl<E: Endpoint> Endpoint for CompressionEndpoint<E> {
147    type Output = Response;
148
149    async fn call(&self, mut req: Request) -> Result<Self::Output> {
150        // decompress request body
151        if let Some(algo) = req
152            .headers()
153            .get(header::CONTENT_ENCODING)
154            .and_then(|value| value.to_str().ok())
155            .and_then(|value| CompressionAlgo::from_str(value).ok())
156        {
157            let new_body = algo.decompress(req.take_body().into_async_read());
158            req.set_body(Body::from_async_read(new_body));
159        }
160
161        // negotiate content-encoding
162        let compress_algo =
163            parse_accept_encoding(req.headers(), &self.algorithms).map(|coding| match coding {
164                ContentCoding::Gzip => CompressionAlgo::GZIP,
165                ContentCoding::Deflate => CompressionAlgo::DEFLATE,
166                ContentCoding::Brotli => CompressionAlgo::BR,
167                ContentCoding::Star | ContentCoding::Zstd => CompressionAlgo::ZSTD,
168            });
169
170        let resp = self.ep.call(req).await?;
171        match compress_algo {
172            Some(algo) => {
173                let mut compress = Compress::new(resp, algo);
174                if let Some(level) = self.level {
175                    compress = compress.with_quality(level);
176                }
177                Ok(compress.into_response())
178            }
179            None => Ok(resp.into_response()),
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use super::*;
    use crate::{EndpointExt, handler, test::TestClient};

    /// Request payload sent by every test.
    const DATA: &str = "abcdefghijklmnopqrstuvwxyz1234567890";
    /// `DATA` reversed, i.e. what `index` is expected to answer with.
    const DATA_REV: &str = "0987654321zyxwvutsrqponmlkjihgfedcba";

    /// Replies with the request body's bytes in reverse order.
    #[handler(internal)]
    async fn index(data: String) -> String {
        let mut bytes = data.into_bytes();
        bytes.reverse();
        String::from_utf8(bytes).unwrap()
    }

    /// Decodes `resp`'s body with `algo` and asserts it equals `DATA_REV`.
    async fn assert_body_reversed(resp: Response, algo: CompressionAlgo) {
        let mut reader = algo.decompress(resp.into_body().into_async_read());
        let mut decoded = Vec::new();
        reader.read_to_end(&mut decoded).await.unwrap();
        assert_eq!(decoded, DATA_REV.as_bytes());
    }

    /// Round-trips `DATA` through the middleware, using `algo` both as the
    /// request `Content-Encoding` and as the only accepted response encoding.
    async fn test_algo(algo: CompressionAlgo) {
        let cli = TestClient::new(index.with(Compression::default()));

        let compressed_body = Body::from_async_read(algo.compress(DATA.as_bytes(), None));
        let resp = cli
            .post("/")
            .header("Content-Encoding", algo.as_str())
            .header("Accept-Encoding", algo.as_str())
            .body(compressed_body)
            .send()
            .await;

        resp.assert_status_is_ok();
        resp.assert_header("Content-Encoding", algo.as_str());
        assert_body_reversed(resp.0, algo).await;
    }

    /// Posts plain `DATA` with the given `Accept-Encoding` to a default
    /// `Compression` endpoint and checks the negotiated encoding and body.
    async fn assert_negotiates(accept: &str, expected: &str, algo: CompressionAlgo) {
        let cli = TestClient::new(index.with(Compression::default()));
        let resp = cli
            .post("/")
            .header("Accept-Encoding", accept)
            .body(DATA)
            .send()
            .await;

        resp.assert_status_is_ok();
        resp.assert_header("Content-Encoding", expected);
        assert_body_reversed(resp.0, algo).await;
    }

    #[tokio::test]
    async fn test_compression() {
        for algo in [
            CompressionAlgo::BR,
            CompressionAlgo::DEFLATE,
            CompressionAlgo::GZIP,
        ] {
            test_algo(algo).await;
        }
    }

    #[tokio::test]
    async fn test_negotiate() {
        assert_negotiates(
            "identity; q=0.5, gzip;q=1.0, br;q=0.3",
            "gzip",
            CompressionAlgo::GZIP,
        )
        .await;
    }

    #[tokio::test]
    async fn test_star() {
        assert_negotiates(
            "identity; q=0.5, *;q=1.0, br;q=0.3",
            "zstd",
            CompressionAlgo::ZSTD,
        )
        .await;
    }

    #[tokio::test]
    async fn test_coding_priority() {
        assert_negotiates("gzip, deflate, br", "br", CompressionAlgo::BR).await;
    }

    #[tokio::test]
    async fn test_enabled_algorithms() {
        for (only, expected) in [(CompressionAlgo::GZIP, "gzip"), (CompressionAlgo::BR, "br")] {
            let cli = TestClient::new(index.with(Compression::default().algorithms([only])));
            let resp = cli
                .post("/")
                .header("Accept-Encoding", "gzip, deflate, br")
                .body(DATA)
                .send()
                .await;

            resp.assert_status_is_ok();
            resp.assert_header("Content-Encoding", expected);
        }
    }
}