1use std::{
2 fmt::{self, Display, Formatter},
3 pin::Pin,
4 str::FromStr,
5};
6
7use tokio::io::{AsyncRead, BufReader};
8
9use crate::{
10 Body, IntoResponse, Response,
11 http::{HeaderValue, header},
12 web::CompressionLevel,
13};
14
/// The compression algorithms that can be used as an HTTP content coding.
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum CompressionAlgo {
    /// Brotli, content-coding token `br`.
    BR,
    /// Deflate, content-coding token `deflate`.
    DEFLATE,
    /// Gzip, content-coding token `gzip`.
    GZIP,
    /// Zstandard, content-coding token `zstd`.
    ZSTD,
}

impl FromStr for CompressionAlgo {
    type Err = ();

    /// Parses an HTTP content-coding token (`br`, `deflate`, `gzip` or `zstd`).
    ///
    /// Content codings are case-insensitive (RFC 9110, section 8.4.1), so tokens
    /// such as `GZIP` or `Br` are accepted as well. Any other input, including the
    /// empty string, yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Keep this table in sync with the variants above; every variant must
        // appear exactly once so that parsing round-trips with its token.
        const TOKENS: [(&str, CompressionAlgo); 4] = [
            ("br", CompressionAlgo::BR),
            ("deflate", CompressionAlgo::DEFLATE),
            ("gzip", CompressionAlgo::GZIP),
            ("zstd", CompressionAlgo::ZSTD),
        ];

        TOKENS
            .iter()
            .find(|(token, _)| token.eq_ignore_ascii_case(s))
            .map(|&(_, algo)| algo)
            .ok_or(())
    }
}
42
43impl CompressionAlgo {
44 #[inline]
45 pub(crate) fn as_str(&self) -> &'static str {
46 match self {
47 CompressionAlgo::BR => "br",
48 CompressionAlgo::DEFLATE => "deflate",
49 CompressionAlgo::GZIP => "gzip",
50 CompressionAlgo::ZSTD => "zstd",
51 }
52 }
53
54 pub(crate) fn compress<'a>(
55 &self,
56 reader: impl AsyncRead + Send + Unpin + 'a,
57 level: Option<CompressionLevel>,
58 ) -> Pin<Box<dyn AsyncRead + Send + 'a>> {
59 match self {
60 CompressionAlgo::BR => Box::pin(
61 async_compression::tokio::bufread::BrotliEncoder::with_quality(
62 BufReader::new(reader),
63 level.unwrap_or(CompressionLevel::Fastest),
64 ),
65 ),
66 CompressionAlgo::DEFLATE => Box::pin(
67 async_compression::tokio::bufread::DeflateEncoder::with_quality(
68 BufReader::new(reader),
69 level.unwrap_or(CompressionLevel::Default),
70 ),
71 ),
72 CompressionAlgo::GZIP => Box::pin(
73 async_compression::tokio::bufread::GzipEncoder::with_quality(
74 BufReader::new(reader),
75 level.unwrap_or(CompressionLevel::Default),
76 ),
77 ),
78 CompressionAlgo::ZSTD => Box::pin(
79 async_compression::tokio::bufread::ZstdEncoder::with_quality(
80 BufReader::new(reader),
81 level.unwrap_or(CompressionLevel::Default),
82 ),
83 ),
84 }
85 }
86
87 pub(crate) fn decompress<'a>(
88 &self,
89 reader: impl AsyncRead + Send + Unpin + 'a,
90 ) -> Pin<Box<dyn AsyncRead + Send + 'a>> {
91 match self {
92 CompressionAlgo::BR => Box::pin(async_compression::tokio::bufread::BrotliDecoder::new(
93 BufReader::new(reader),
94 )),
95 CompressionAlgo::DEFLATE => Box::pin(
96 async_compression::tokio::bufread::DeflateDecoder::new(BufReader::new(reader)),
97 ),
98 CompressionAlgo::GZIP => Box::pin(async_compression::tokio::bufread::GzipDecoder::new(
99 BufReader::new(reader),
100 )),
101 CompressionAlgo::ZSTD => Box::pin(async_compression::tokio::bufread::ZstdDecoder::new(
102 BufReader::new(reader),
103 )),
104 }
105 }
106}
107
108impl Display for CompressionAlgo {
109 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
110 write!(f, "{}", self.as_str())
111 }
112}
113
114#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
131pub struct Compress<T> {
132 inner: T,
133 algo: CompressionAlgo,
134 level: Option<CompressionLevel>,
135}
136
137impl<T> Compress<T> {
138 pub fn new(inner: T, algo: CompressionAlgo) -> Self {
140 Self {
141 inner,
142 algo,
143 level: None,
144 }
145 }
146
147 #[must_use]
149 #[inline]
150 pub fn with_quality(self, level: CompressionLevel) -> Self {
151 Self {
152 level: Some(level),
153 ..self
154 }
155 }
156}
157
158impl<T: IntoResponse> IntoResponse for Compress<T> {
159 fn into_response(self) -> Response {
160 let mut resp = self.inner.into_response();
161 let body = resp.take_body();
162
163 resp.headers_mut().append(
164 header::CONTENT_ENCODING,
165 HeaderValue::from_static(self.algo.as_str()),
166 );
167 resp.headers_mut().remove(header::CONTENT_LENGTH);
168
169 resp.set_body(Body::from_async_read(
170 self.algo.compress(body.into_async_read(), self.level),
171 ));
172 resp
173 }
174}
175
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use super::*;
    use crate::{EndpointExt, handler, test::TestClient};

    /// Every supported algorithm; keep in sync with `CompressionAlgo`.
    const ALL_ALGOS: [CompressionAlgo; 4] = [
        CompressionAlgo::BR,
        CompressionAlgo::DEFLATE,
        CompressionAlgo::GZIP,
        CompressionAlgo::ZSTD,
    ];

    /// Decodes `data` with `algo` and returns it as UTF-8 text.
    async fn decompress_data(algo: CompressionAlgo, data: &[u8]) -> String {
        let mut output = Vec::new();

        let mut dec = algo.decompress(data);
        dec.read_to_end(&mut output).await.unwrap();
        String::from_utf8(output).unwrap()
    }

    /// Serves a fixed string through `Compress` and checks the headers and
    /// that the body decodes back to the original data.
    async fn test_algo(algo: CompressionAlgo) {
        const DATA: &str = "abcdefghijklmnopqrstuvwxyz1234567890";

        #[handler(internal)]
        async fn index() -> &'static str {
            DATA
        }

        let resp = TestClient::new(
            index.and_then(move |resp| async move { Ok(Compress::new(resp, algo)) }),
        )
        .get("/")
        .send()
        .await;
        resp.assert_status_is_ok();
        resp.assert_header(header::CONTENT_ENCODING, algo.as_str());
        resp.assert_header_is_not_exist(header::CONTENT_LENGTH);
        assert_eq!(
            decompress_data(algo, &resp.0.into_body().into_bytes().await.unwrap()).await,
            DATA
        );
    }

    #[tokio::test]
    async fn test_compress() {
        // Previously ZSTD was never exercised even though it is supported.
        for algo in ALL_ALGOS {
            test_algo(algo).await;
        }
    }

    #[test]
    fn test_algo_token_round_trip() {
        for algo in ALL_ALGOS {
            assert_eq!(algo.as_str().parse::<CompressionAlgo>(), Ok(algo));
            assert_eq!(algo.to_string(), algo.as_str());
        }
        assert_eq!("identity".parse::<CompressionAlgo>(), Err(()));
    }
}