poem/web/
compress.rs

1use std::{
2    fmt::{self, Display, Formatter},
3    pin::Pin,
4    str::FromStr,
5};
6
7use tokio::io::{AsyncRead, BufReader};
8
9use crate::{
10    Body, IntoResponse, Response,
11    http::{HeaderValue, header},
12    web::CompressionLevel,
13};
14
/// The compression algorithms.
///
/// Each variant corresponds to one HTTP content-coding token (`br`,
/// `deflate`, `gzip`, `zstd`). Tokens are parsed with [`FromStr`].
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum CompressionAlgo {
    /// brotli (`br`)
    BR,
    /// deflate (`deflate`)
    DEFLATE,
    /// gzip (`gzip`, also accepted as `x-gzip` when parsing)
    GZIP,
    /// Zstandard (`zstd`)
    ZSTD,
}

impl FromStr for CompressionAlgo {
    type Err = ();

    /// Parses an HTTP content-coding token into an algorithm.
    ///
    /// Matching ignores ASCII case, because content-coding tokens are
    /// case-insensitive (RFC 9110 §8.4.1). The legacy `x-gzip` token is
    /// treated as `gzip` (RFC 9110 §8.4.1.3).
    ///
    /// # Errors
    ///
    /// Returns `Err(())` for any token that is not a supported coding.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `x-gzip` is an alias, so two tokens map to the same variant.
        const CODINGS: &[(&str, CompressionAlgo)] = &[
            ("br", CompressionAlgo::BR),
            ("deflate", CompressionAlgo::DEFLATE),
            ("gzip", CompressionAlgo::GZIP),
            ("x-gzip", CompressionAlgo::GZIP),
            ("zstd", CompressionAlgo::ZSTD),
        ];

        CODINGS
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|&(_, algo)| algo)
            .ok_or(())
    }
}
42
43impl CompressionAlgo {
44    #[inline]
45    pub(crate) fn as_str(&self) -> &'static str {
46        match self {
47            CompressionAlgo::BR => "br",
48            CompressionAlgo::DEFLATE => "deflate",
49            CompressionAlgo::GZIP => "gzip",
50            CompressionAlgo::ZSTD => "zstd",
51        }
52    }
53
54    pub(crate) fn compress<'a>(
55        &self,
56        reader: impl AsyncRead + Send + Unpin + 'a,
57        level: Option<CompressionLevel>,
58    ) -> Pin<Box<dyn AsyncRead + Send + 'a>> {
59        match self {
60            CompressionAlgo::BR => Box::pin(
61                async_compression::tokio::bufread::BrotliEncoder::with_quality(
62                    BufReader::new(reader),
63                    level.unwrap_or(CompressionLevel::Fastest),
64                ),
65            ),
66            CompressionAlgo::DEFLATE => Box::pin(
67                async_compression::tokio::bufread::DeflateEncoder::with_quality(
68                    BufReader::new(reader),
69                    level.unwrap_or(CompressionLevel::Default),
70                ),
71            ),
72            CompressionAlgo::GZIP => Box::pin(
73                async_compression::tokio::bufread::GzipEncoder::with_quality(
74                    BufReader::new(reader),
75                    level.unwrap_or(CompressionLevel::Default),
76                ),
77            ),
78            CompressionAlgo::ZSTD => Box::pin(
79                async_compression::tokio::bufread::ZstdEncoder::with_quality(
80                    BufReader::new(reader),
81                    level.unwrap_or(CompressionLevel::Default),
82                ),
83            ),
84        }
85    }
86
87    pub(crate) fn decompress<'a>(
88        &self,
89        reader: impl AsyncRead + Send + Unpin + 'a,
90    ) -> Pin<Box<dyn AsyncRead + Send + 'a>> {
91        match self {
92            CompressionAlgo::BR => Box::pin(async_compression::tokio::bufread::BrotliDecoder::new(
93                BufReader::new(reader),
94            )),
95            CompressionAlgo::DEFLATE => Box::pin(
96                async_compression::tokio::bufread::DeflateDecoder::new(BufReader::new(reader)),
97            ),
98            CompressionAlgo::GZIP => Box::pin(async_compression::tokio::bufread::GzipDecoder::new(
99                BufReader::new(reader),
100            )),
101            CompressionAlgo::ZSTD => Box::pin(async_compression::tokio::bufread::ZstdDecoder::new(
102                BufReader::new(reader),
103            )),
104        }
105    }
106}
107
108impl Display for CompressionAlgo {
109    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
110        write!(f, "{}", self.as_str())
111    }
112}
113
114/// Compress the response body with the specified algorithm and set the
115/// `Content-Encoding` header.
116///
117/// # Example
118///
119/// ```
120/// use poem::{
121///     handler,
122///     web::{Compress, CompressionAlgo},
123/// };
124///
125/// #[handler]
126/// fn index() -> Compress<String> {
127///     Compress::new("abcdef".to_string(), CompressionAlgo::GZIP)
128/// }
129/// ```
130#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
131pub struct Compress<T> {
132    inner: T,
133    algo: CompressionAlgo,
134    level: Option<CompressionLevel>,
135}
136
137impl<T> Compress<T> {
138    /// Create a compressed response using the specified algorithm.
139    pub fn new(inner: T, algo: CompressionAlgo) -> Self {
140        Self {
141            inner,
142            algo,
143            level: None,
144        }
145    }
146
147    /// Specify the compression level
148    #[must_use]
149    #[inline]
150    pub fn with_quality(self, level: CompressionLevel) -> Self {
151        Self {
152            level: Some(level),
153            ..self
154        }
155    }
156}
157
158impl<T: IntoResponse> IntoResponse for Compress<T> {
159    fn into_response(self) -> Response {
160        let mut resp = self.inner.into_response();
161        let body = resp.take_body();
162
163        resp.headers_mut().append(
164            header::CONTENT_ENCODING,
165            HeaderValue::from_static(self.algo.as_str()),
166        );
167        resp.headers_mut().remove(header::CONTENT_LENGTH);
168
169        resp.set_body(Body::from_async_read(
170            self.algo.compress(body.into_async_read(), self.level),
171        ));
172        resp
173    }
174}
175
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use super::*;
    use crate::{EndpointExt, handler, test::TestClient};

    /// Every supported algorithm, so each test covers all of them (including
    /// `ZSTD`, which was previously never exercised).
    const ALL_ALGOS: [CompressionAlgo; 4] = [
        CompressionAlgo::BR,
        CompressionAlgo::DEFLATE,
        CompressionAlgo::GZIP,
        CompressionAlgo::ZSTD,
    ];

    /// Decodes `data` with `algo` and returns it as a UTF-8 string.
    async fn decompress_data(algo: CompressionAlgo, data: &[u8]) -> String {
        let mut output = Vec::new();

        let mut dec = algo.decompress(data);
        dec.read_to_end(&mut output).await.unwrap();
        String::from_utf8(output).unwrap()
    }

    /// Serves a fixed string through `Compress` and checks the headers and
    /// the round-tripped body.
    async fn test_algo(algo: CompressionAlgo) {
        const DATA: &str = "abcdefghijklmnopqrstuvwxyz1234567890";

        #[handler(internal)]
        async fn index() -> &'static str {
            DATA
        }

        let resp = TestClient::new(
            index.and_then(move |resp| async move { Ok(Compress::new(resp, algo)) }),
        )
        .get("/")
        .send()
        .await;
        resp.assert_status_is_ok();
        resp.assert_header(header::CONTENT_ENCODING, algo.as_str());
        resp.assert_header_is_not_exist(header::CONTENT_LENGTH);
        assert_eq!(
            decompress_data(algo, &resp.0.into_body().into_bytes().await.unwrap()).await,
            DATA
        );
    }

    #[tokio::test]
    async fn test_compress() {
        for algo in ALL_ALGOS {
            test_algo(algo).await;
        }
    }

    #[test]
    fn test_algo_name_round_trip() {
        for algo in ALL_ALGOS {
            assert_eq!(algo.as_str().parse::<CompressionAlgo>(), Ok(algo));
            assert_eq!(algo.to_string(), algo.as_str());
        }
        assert!("identity".parse::<CompressionAlgo>().is_err());
    }
}