rama_http/layer/compression/
mod.rs1pub mod predicate;
77
78pub(crate) mod body;
79mod layer;
80mod pin_project_cfg;
81mod service;
82
83#[doc(inline)]
84pub use self::{
85 body::CompressionBody,
86 layer::CompressionLayer,
87 predicate::{DefaultPredicate, Predicate},
88 service::Compression,
89};
90#[doc(inline)]
91pub use crate::layer::util::compression::CompressionLevel;
92
#[cfg(test)]
mod tests {
    use super::*;

    use crate::layer::compression::predicate::SizeAbove;

    use crate::dep::http_body::Body as _;
    use crate::dep::http_body_util::BodyExt;
    use crate::header::{
        ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE,
    };
    use crate::{Body, HeaderValue, Request, Response};
    use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
    use flate2::read::GzDecoder;
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;
    use std::io::Read;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_util::io::StreamReader;

    /// Predicate that approves compression for every response, so the tests
    /// can exercise the encoders independently of the default predicate.
    #[derive(Clone)]
    struct Always;

    impl Predicate for Always {
        fn should_compress<B>(&self, _: &rama_http_types::Response<B>) -> bool
        where
            B: rama_http_types::dep::http_body::Body,
        {
            true
        }
    }

    #[tokio::test]
    async fn gzip_works() {
        let svc = service_fn(handle);
        let svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();

        // the negotiated encoding must be advertised
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

        // read the compressed body
        let collected = res.into_body().collect().await.unwrap();
        let compressed_data = collected.to_bytes();

        // decompress the body
        let mut decoder = GzDecoder::new(&compressed_data[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");
    }

    #[tokio::test]
    async fn x_gzip_works() {
        let svc = service_fn(handle);
        let svc = Compression::new(svc).compress_when(Always);

        // call the service, asking for the legacy `x-gzip` alias
        let req = Request::builder()
            .header("accept-encoding", "x-gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();

        // the response is labelled with the canonical `gzip` token, exactly once
        assert_eq!(
            res.headers()
                .get_all("content-encoding")
                .iter()
                .collect::<Vec<&HeaderValue>>(),
            vec![HeaderValue::from_static("gzip")]
        );

        // read the compressed body
        let collected = res.into_body().collect().await.unwrap();
        let compressed_data = collected.to_bytes();

        // decompress the body
        let mut decoder = GzDecoder::new(&compressed_data[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");
    }

    #[tokio::test]
    async fn zstd_works() {
        let svc = service_fn(handle);
        let svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "zstd")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();

        // the negotiated encoding must be advertised
        assert_eq!(res.headers()[CONTENT_ENCODING], "zstd");

        // read the compressed body
        let body = res.into_body();
        let compressed_data = body.collect().await.unwrap().to_bytes();

        // decompress the body
        let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
        let decompressed = String::from_utf8(decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");
    }

    #[tokio::test]
    async fn no_recompress() {
        const DATA: &str = "Hello, World! I'm already compressed with br!";

        // the inner service returns a body that is already brotli-encoded
        let svc = service_fn(async |_| {
            let buf = {
                let mut buf = Vec::new();

                let mut enc = BrotliEncoder::new(&mut buf);
                enc.write_all(DATA.as_bytes()).await?;
                enc.flush().await?;
                buf
            };

            let resp = Response::builder()
                .header("content-encoding", "br")
                .body(Body::from(buf))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let svc = Compression::new(svc);

        // call the service
        //
        // note: the accept-encoding deliberately differs from the response's
        // content-encoding, so a recompression by the layer would be visible
        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();

        // the existing encoding must be left untouched
        assert_eq!(
            res.headers()
                .get("content-encoding")
                .and_then(|h| h.to_str().ok())
                .unwrap_or_default(),
            "br",
        );

        // read the body
        let body = res.into_body();
        let data = body.collect().await.unwrap().to_bytes();

        // it must still be the original brotli payload
        let data = {
            let mut output_buf = Vec::new();
            let mut decoder = BrotliDecoder::new(&mut output_buf);
            decoder
                .write_all(&data)
                .await
                .expect("couldn't brotli-decode");
            decoder.flush().await.expect("couldn't flush");
            output_buf
        };

        assert_eq!(data, DATA.as_bytes());
    }

    /// Shared inner service: always answers with a plain `Hello, World!` body.
    async fn handle(_req: Request) -> Result<Response, Infallible> {
        let body = Body::from("Hello, World!");
        Ok(Response::builder().body(body).unwrap())
    }

    #[tokio::test]
    async fn will_not_compress_if_filtered_out() {
        const DATA: &str = "Hello world uncompressed";

        let svc_fn = service_fn(async |_| {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        /// Approves compression on every second call: the first response is
        /// left alone, the second is compressed, and so on.
        #[derive(Default, Clone)]
        struct EveryOtherResponse(Arc<AtomicU64>);

        impl Predicate for EveryOtherResponse {
            fn should_compress<B>(&self, _: &rama_http_types::Response<B>) -> bool
            where
                B: rama_http_types::dep::http_body::Body,
            {
                // `fetch_add` yields the count *before* incrementing, so the
                // decision is made on the old value and the counter is bumped.
                self.0.fetch_add(1, Ordering::SeqCst) % 2 != 0
            }
        }

        let svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default());

        // first call: the predicate says no, so the body stays plain text
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();

        let body = res.into_body();
        let data = body.collect().await.unwrap().to_bytes();
        let still_uncompressed = String::from_utf8(data.to_vec()).unwrap();
        assert_eq!(DATA, &still_uncompressed);

        // second call: the predicate says yes, so the body is brotli bytes
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();

        let body = res.into_body();
        let data = body.collect().await.unwrap().to_bytes();
        assert!(String::from_utf8(data.to_vec()).is_err());
    }

    #[tokio::test]
    async fn doesnt_compress_images() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            // large enough to pass the default size threshold
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, "image/png".parse().unwrap());
            Ok(res)
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .serve(
                Context::default(),
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn does_compress_svg() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            // large enough to pass the default size threshold
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap());
            Ok(res)
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .serve(
                Context::default(),
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
    }

    #[tokio::test]
    async fn compress_with_quality() {
        const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
        let level = CompressionLevel::Best;

        let svc = service_fn(async |_| {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        let svc = Compression::new(svc).quality(level);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();

        // read the compressed body
        let body = res.into_body();
        let compressed_data = body.collect().await.unwrap().to_bytes();

        // compress the same data directly at the same level for comparison
        let compressed_with_level = {
            use async_compression::tokio::bufread::BrotliEncoder;

            let stream = Box::pin(futures_lite::stream::once(Ok::<_, std::io::Error>(
                DATA.as_bytes(),
            )));
            let reader = StreamReader::new(stream);
            let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression());

            let mut buf = Vec::new();
            enc.read_to_end(&mut buf).await.unwrap();
            buf
        };

        assert_eq!(
            compressed_data,
            compressed_with_level.as_slice(),
            "Compression level is not respected"
        );
    }

    #[tokio::test]
    async fn should_not_compress_ranges() {
        let svc = service_fn(async |_| {
            let mut res = Response::new(Body::from("Hello"));
            let headers = res.headers_mut();
            headers.insert(ACCEPT_RANGES, "bytes".parse().unwrap());
            headers.insert(CONTENT_RANGE, "bytes 0-4/*".parse().unwrap());
            Ok::<_, std::io::Error>(res)
        });
        let svc = Compression::new(svc).compress_when(Always);

        // call the service with a range request
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .header(RANGE, "bytes=0-4")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        let headers = res.headers().clone();

        // read the (uncompressed) body
        let collected = res.into_body().collect().await.unwrap().to_bytes();

        assert_eq!(headers[ACCEPT_RANGES], "bytes");
        assert!(!headers.contains_key(CONTENT_ENCODING));
        assert_eq!(collected, "Hello");
    }

    #[tokio::test]
    async fn should_strip_accept_ranges_header_when_compressing() {
        let svc = service_fn(async |_| {
            let mut res = Response::new(Body::from("Hello, World!"));
            res.headers_mut()
                .insert(ACCEPT_RANGES, "bytes".parse().unwrap());
            Ok::<_, std::io::Error>(res)
        });
        let svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        let headers = res.headers().clone();

        // read the compressed body
        let collected = res.into_body().collect().await.unwrap();
        let compressed_data = collected.to_bytes();

        // decompress the body
        let mut decoder = GzDecoder::new(&compressed_data[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();

        assert!(!headers.contains_key(ACCEPT_RANGES));
        assert_eq!(headers[CONTENT_ENCODING], "gzip");
        assert_eq!(decompressed, "Hello, World!");
    }

    #[tokio::test]
    async fn size_hint_identity() {
        const MSG: &str = "Hello, world!";
        let svc = service_fn(async |_| Ok::<_, std::io::Error>(Response::new(Body::from(MSG))));
        let svc = Compression::new(svc);

        // no accept-encoding: the body passes through and keeps its exact size hint
        let req = Request::new(Body::empty());
        let res = svc.serve(Context::default(), req).await.unwrap();
        let body = res.into_body();
        assert_eq!(body.size_hint().exact().unwrap(), MSG.len() as u64);
    }
}