1use std::{collections::HashSet, str::FromStr};
2
3use headers::HeaderMap;
4
5use crate::{
6 Body, Endpoint, IntoResponse, Middleware, Request, Response, Result,
7 http::header,
8 web::{Compress, CompressionAlgo, CompressionLevel},
9};
10
/// A content coding recognized in the `Accept-Encoding` request header.
enum ContentCoding {
    /// The `br` (Brotli) coding.
    Brotli,
    /// The `deflate` coding.
    Deflate,
    /// The `gzip` coding.
    Gzip,
    /// The `*` wildcard: the client accepts any coding.
    Star,
    /// The `zstd` (Zstandard) coding.
    Zstd,
}
18
19impl FromStr for ContentCoding {
20 type Err = ();
21
22 fn from_str(s: &str) -> Result<Self, Self::Err> {
23 if s.eq_ignore_ascii_case("deflate") {
24 Ok(ContentCoding::Deflate)
25 } else if s.eq_ignore_ascii_case("gzip") {
26 Ok(ContentCoding::Gzip)
27 } else if s.eq_ignore_ascii_case("br") {
28 Ok(ContentCoding::Brotli)
29 } else if s == "*" {
30 Ok(ContentCoding::Star)
31 } else if s == "zsdt" {
32 Ok(ContentCoding::Zstd)
33 } else {
34 Err(())
35 }
36 }
37}
38
39fn parse_accept_encoding(
40 headers: &HeaderMap,
41 enabled_algorithms: &HashSet<CompressionAlgo>,
42) -> Option<ContentCoding> {
43 headers
44 .get_all(header::ACCEPT_ENCODING)
45 .iter()
46 .filter_map(|hval| hval.to_str().ok())
47 .flat_map(|s| s.split(',').map(str::trim))
48 .filter_map(|v| {
49 let (e, q) = match v.split_once(";q=") {
50 Some((e, q)) => (e, (q.parse::<f32>().ok()? * 1000.0) as i32),
51 None => (v, 1000),
52 };
53 let coding: ContentCoding = e.parse().ok()?;
54 Some((coding, q))
55 })
56 .filter(|(encoding, _)| {
57 if !enabled_algorithms.is_empty() {
58 match encoding {
59 ContentCoding::Brotli => enabled_algorithms.contains(&CompressionAlgo::BR),
60 ContentCoding::Deflate => {
61 enabled_algorithms.contains(&CompressionAlgo::DEFLATE)
62 }
63 ContentCoding::Gzip => enabled_algorithms.contains(&CompressionAlgo::GZIP),
64 ContentCoding::Zstd => enabled_algorithms.contains(&CompressionAlgo::ZSTD),
65 _ => true,
66 }
67 } else {
68 true
69 }
70 })
71 .max_by_key(|(coding, q)| (*q, coding_priority(coding)))
72 .map(|(coding, _)| coding)
73}
74
/// Middleware that decompresses request bodies (driven by
/// `Content-Encoding`) and compresses response bodies (negotiated from
/// `Accept-Encoding`).
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
#[derive(Default)]
pub struct Compression {
    // Compression level for responses; `None` lets the codec pick its default.
    level: Option<CompressionLevel>,
    // Algorithms allowed during negotiation; empty means "all supported".
    algorithms: HashSet<CompressionAlgo>,
}
86
87impl Compression {
88 #[must_use]
90 pub fn new() -> Self {
91 Self::default()
92 }
93
94 #[must_use]
96 #[inline]
97 pub fn with_quality(self, level: CompressionLevel) -> Self {
98 Self {
99 level: Some(level),
100 ..self
101 }
102 }
103
104 #[must_use]
106 #[inline]
107 pub fn algorithms(self, algorithms: impl IntoIterator<Item = CompressionAlgo>) -> Self {
108 Self {
109 algorithms: algorithms.into_iter().collect(),
110 ..self
111 }
112 }
113}
114
impl<E: Endpoint> Middleware<E> for Compression {
    type Output = CompressionEndpoint<E>;

    // Wraps the inner endpoint, carrying over the configured level and the
    // set of allowed algorithms.
    fn transform(&self, ep: E) -> Self::Output {
        CompressionEndpoint {
            ep,
            level: self.level,
            algorithms: self.algorithms.clone(),
        }
    }
}
126
/// The endpoint produced by the [`Compression`] middleware.
#[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
pub struct CompressionEndpoint<E: Endpoint> {
    // The wrapped inner endpoint.
    ep: E,
    // Response compression level; `None` uses the codec default.
    level: Option<CompressionLevel>,
    // Allowed algorithms; empty means "all supported".
    algorithms: HashSet<CompressionAlgo>,
}
134
135#[inline]
136fn coding_priority(c: &ContentCoding) -> u8 {
137 match *c {
138 ContentCoding::Deflate => 1,
139 ContentCoding::Gzip => 2,
140 ContentCoding::Brotli => 3,
141 ContentCoding::Zstd => 4,
142 _ => 0,
143 }
144}
145
impl<E: Endpoint> Endpoint for CompressionEndpoint<E> {
    type Output = Response;

    async fn call(&self, mut req: Request) -> Result<Self::Output> {
        // If the client compressed the request body (`Content-Encoding`),
        // transparently swap in a decompressing reader before the inner
        // endpoint consumes it. Unknown encodings are silently left as-is.
        if let Some(algo) = req
            .headers()
            .get(header::CONTENT_ENCODING)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| CompressionAlgo::from_str(value).ok())
        {
            let new_body = algo.decompress(req.take_body().into_async_read());
            req.set_body(Body::from_async_read(new_body));
        }

        // Negotiate the response coding now, before `req` is moved into the
        // inner endpoint. The `*` wildcard resolves to zstd.
        let compress_algo =
            parse_accept_encoding(req.headers(), &self.algorithms).map(|coding| match coding {
                ContentCoding::Gzip => CompressionAlgo::GZIP,
                ContentCoding::Deflate => CompressionAlgo::DEFLATE,
                ContentCoding::Brotli => CompressionAlgo::BR,
                ContentCoding::Star | ContentCoding::Zstd => CompressionAlgo::ZSTD,
            });

        let resp = self.ep.call(req).await?;
        // Wrap the response in a compressing body only when a coding was
        // negotiated; otherwise pass the response through untouched.
        match compress_algo {
            Some(algo) => {
                let mut compress = Compress::new(resp, algo);
                if let Some(level) = self.level {
                    compress = compress.with_quality(level);
                }
                Ok(compress.into_response())
            }
            None => Ok(resp.into_response()),
        }
    }
}
183
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use super::*;
    use crate::{EndpointExt, handler, test::TestClient};

    // Fixture payload and its reversal (what `index` is expected to return).
    const DATA: &str = "abcdefghijklmnopqrstuvwxyz1234567890";
    const DATA_REV: &str = "0987654321zyxwvutsrqponmlkjihgfedcba";

    // Echo-style handler: returns the request body reversed, so tests can
    // verify the payload survived a decompress/compress round trip.
    #[handler(internal)]
    async fn index(data: String) -> String {
        String::from_utf8(data.into_bytes().into_iter().rev().collect()).unwrap()
    }

    // Round-trips DATA through the middleware using the same algorithm for
    // both the compressed request body and the negotiated response body.
    async fn test_algo(algo: CompressionAlgo) {
        let ep = index.with(Compression::default());
        let cli = TestClient::new(ep);

        let resp = cli
            .post("/")
            .header("Content-Encoding", algo.as_str())
            .header("Accept-Encoding", algo.as_str())
            .body(Body::from_async_read(algo.compress(DATA.as_bytes(), None)))
            .send()
            .await;

        resp.assert_status_is_ok();
        resp.assert_header("Content-Encoding", algo.as_str());

        // Decompress the response body and check it is the reversed payload.
        let mut data = Vec::new();
        let mut reader = algo.decompress(resp.0.into_body().into_async_read());
        reader.read_to_end(&mut data).await.unwrap();
        assert_eq!(data, DATA_REV.as_bytes());
    }

    #[tokio::test]
    async fn test_compression() {
        test_algo(CompressionAlgo::BR).await;
        test_algo(CompressionAlgo::DEFLATE).await;
        test_algo(CompressionAlgo::GZIP).await;
    }

    // The coding with the highest q-value (gzip at 1.0) must win.
    #[tokio::test]
    async fn test_negotiate() {
        let ep = index.with(Compression::default());
        let cli = TestClient::new(ep);

        let resp = cli
            .post("/")
            .header("Accept-Encoding", "identity; q=0.5, gzip;q=1.0, br;q=0.3")
            .body(DATA)
            .send()
            .await;
        resp.assert_status_is_ok();
        resp.assert_header("Content-Encoding", "gzip");

        let mut data = Vec::new();
        let mut reader = CompressionAlgo::GZIP.decompress(resp.0.into_body().into_async_read());
        reader.read_to_end(&mut data).await.unwrap();
        assert_eq!(data, DATA_REV.as_bytes());
    }

    // The `*` wildcard resolves to zstd (see `Endpoint::call`).
    #[tokio::test]
    async fn test_star() {
        let ep = index.with(Compression::default());
        let cli = TestClient::new(ep);

        let resp = cli
            .post("/")
            .header("Accept-Encoding", "identity; q=0.5, *;q=1.0, br;q=0.3")
            .body(DATA)
            .send()
            .await;
        resp.assert_status_is_ok();
        resp.assert_header("Content-Encoding", "zstd");

        let mut data = Vec::new();
        let mut reader = CompressionAlgo::ZSTD.decompress(resp.0.into_body().into_async_read());
        reader.read_to_end(&mut data).await.unwrap();
        assert_eq!(data, DATA_REV.as_bytes());
    }

    // With equal (implicit) q-values, `coding_priority` breaks the tie:
    // br outranks gzip and deflate.
    #[tokio::test]
    async fn test_coding_priority() {
        let ep = index.with(Compression::default());
        let cli = TestClient::new(ep);

        let resp = cli
            .post("/")
            .header("Accept-Encoding", "gzip, deflate, br")
            .body(DATA)
            .send()
            .await;
        resp.assert_status_is_ok();
        resp.assert_header("Content-Encoding", "br");

        let mut data = Vec::new();
        let mut reader = CompressionAlgo::BR.decompress(resp.0.into_body().into_async_read());
        reader.read_to_end(&mut data).await.unwrap();
        assert_eq!(data, DATA_REV.as_bytes());
    }

    // Restricting `algorithms` must override the default priority order.
    #[tokio::test]
    async fn test_enabled_algorithms() {
        let ep = index.with(Compression::default().algorithms([CompressionAlgo::GZIP]));
        let cli = TestClient::new(ep);

        let resp = cli
            .post("/")
            .header("Accept-Encoding", "gzip, deflate, br")
            .body(DATA)
            .send()
            .await;
        resp.assert_status_is_ok();
        resp.assert_header("Content-Encoding", "gzip");

        let ep = index.with(Compression::default().algorithms([CompressionAlgo::BR]));
        let cli = TestClient::new(ep);

        let resp = cli
            .post("/")
            .header("Accept-Encoding", "gzip, deflate, br")
            .body(DATA)
            .send()
            .await;
        resp.assert_status_is_ok();
        resp.assert_header("Content-Encoding", "br");
    }
}