rama_http/layer/compression/
layer.rs1use super::predicate::DefaultPredicate;
2use super::{Compression, Predicate};
3use crate::layer::util::compression::CompressionLevel;
4use rama_core::Layer;
5use rama_http_types::headers::encoding::AcceptEncoding;
6
7#[derive(Clone, Debug, Default)]
14pub struct CompressionLayer<P = DefaultPredicate> {
15 accept: AcceptEncoding,
16 predicate: P,
17 quality: CompressionLevel,
18}
19
20impl<S, P> Layer<S> for CompressionLayer<P>
21where
22 P: Predicate,
23{
24 type Service = Compression<S, P>;
25
26 fn layer(&self, inner: S) -> Self::Service {
27 Compression {
28 inner,
29 accept: self.accept,
30 predicate: self.predicate.clone(),
31 quality: self.quality,
32 }
33 }
34
35 fn into_layer(self, inner: S) -> Self::Service {
36 Compression {
37 inner,
38 accept: self.accept,
39 predicate: self.predicate,
40 quality: self.quality,
41 }
42 }
43}
44
45impl CompressionLayer {
46 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn gzip(mut self, enable: bool) -> Self {
53 self.accept.set_gzip(enable);
54 self
55 }
56
57 pub fn set_gzip(&mut self, enable: bool) -> &mut Self {
59 self.accept.set_gzip(enable);
60 self
61 }
62
63 pub fn deflate(mut self, enable: bool) -> Self {
65 self.accept.set_deflate(enable);
66 self
67 }
68
69 pub fn set_deflate(&mut self, enable: bool) -> &mut Self {
71 self.accept.set_deflate(enable);
72 self
73 }
74
75 pub fn br(mut self, enable: bool) -> Self {
77 self.accept.set_br(enable);
78 self
79 }
80
81 pub fn set_br(&mut self, enable: bool) -> &mut Self {
83 self.accept.set_br(enable);
84 self
85 }
86
87 pub fn zstd(mut self, enable: bool) -> Self {
89 self.accept.set_zstd(enable);
90 self
91 }
92
93 pub fn set_zstd(&mut self, enable: bool) -> &mut Self {
95 self.accept.set_zstd(enable);
96 self
97 }
98
99 pub fn quality(mut self, quality: CompressionLevel) -> Self {
101 self.quality = quality;
102 self
103 }
104
105 pub fn set_quality(&mut self, quality: CompressionLevel) -> &mut Self {
107 self.quality = quality;
108 self
109 }
110
111 pub fn compress_when<C>(self, predicate: C) -> CompressionLayer<C>
115 where
116 C: Predicate,
117 {
118 CompressionLayer {
119 accept: self.accept,
120 predicate,
121 quality: self.quality,
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    use crate::dep::http_body_util::BodyExt;
    use crate::{Body, Request, Response, header::ACCEPT_ENCODING};
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;
    use tokio::fs::File;
    use tokio_util::io::ReaderStream;

    /// Test handler that streams `Cargo.toml` back as the response body.
    ///
    /// The path is relative to the working directory (presumably the package root
    /// when run through `cargo test`).
    async fn handle(_req: Request) -> Result<Response, Infallible> {
        let file = File::open("Cargo.toml").await.expect("file missing");
        let stream = ReaderStream::new(file);
        let body = Body::from_stream(stream);
        Ok(Response::new(body))
    }

    /// A client offering `gzip, deflate, br` must be answered with the single encoding
    /// the layer still has enabled, and `br` at best quality should produce a smaller
    /// body than `deflate` at best quality.
    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), rama_core::error::BoxError> {
        let deflate_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .br(false)
            .gzip(false);

        let service = deflate_only_layer.into_layer(service_fn(handle));

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.serve(Context::default(), request).await?;

        assert_eq!(response.headers()["content-encoding"], "deflate");

        let deflate_bytes_len = response.into_body().collect().await?.to_bytes().len();

        let br_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .gzip(false)
            .deflate(false);

        let service = br_only_layer.into_layer(service_fn(handle));

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.serve(Context::default(), request).await?;

        assert_eq!(response.headers()["content-encoding"], "br");

        let br_byte_length = response.into_body().collect().await?.to_bytes().len();

        assert!(
            br_byte_length < deflate_bytes_len,
            "expected br output ({br_byte_length} bytes) to be smaller than deflate output ({deflate_bytes_len} bytes)"
        );

        Ok(())
    }

    /// zstd output at the best quality level must still be decodable by a decoder that
    /// caps the window at 2^23 bytes (8 MiB). That is the limit web clients use for the
    /// HTTP `zstd` content coding (see RFC 9659).
    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), rama_core::error::BoxError> {
        // 18 MiB (18 * 1024 * 1024) of zeroes. This is large enough that an encoder
        // without a window cap would presumably want a window above 8 MiB.
        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; 18_874_368])))
        }

        let zstd_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .br(false)
            .deflate(false)
            .gzip(false);

        let service = zstd_layer.into_layer(service_fn(zeroes));

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;

        let response = service.serve(Context::default(), request).await?;

        assert_eq!(response.headers()["content-encoding"], "zstd");

        let bytes = response.into_body().collect().await?.to_bytes();

        let mut dec = zstd::Decoder::new(&*bytes)?;
        // Make the decoder refuse any frame that needs a window larger than 2^23 bytes.
        // If the encoder exceeded the web-safe window, the copy below fails.
        dec.window_log_max(23)?;
        std::io::copy(&mut dec, &mut std::io::sink())?;

        Ok(())
    }
}