elif_http/middleware/utils/
compression.rs1use crate::middleware::v2::{Middleware, Next, NextFuture};
7use crate::request::ElifRequest;
8use crate::response::ElifResponse;
9use http_body_util::BodyExt;
10use tower::{Layer, Service};
11use tower_http::compression::{CompressionLayer, CompressionLevel};
12
13#[derive(Debug, Clone)]
15pub struct CompressionConfig {
16 pub level: CompressionLevel,
18 pub enable_gzip: bool,
20 pub enable_brotli: bool,
22 pub enable_deflate: bool,
24}
25
26impl Default for CompressionConfig {
27 fn default() -> Self {
28 Self {
29 level: CompressionLevel::default(),
30 enable_gzip: true,
31 enable_brotli: true,
32 enable_deflate: false, }
34 }
35}
36
37pub struct CompressionMiddleware {
39 layer: CompressionLayer,
40}
41
42impl CompressionMiddleware {
43 pub fn new() -> Self {
45 let config = CompressionConfig::default();
46 Self::with_config(config)
47 }
48
49 pub fn with_config(config: CompressionConfig) -> Self {
51 let mut layer = CompressionLayer::new().quality(config.level);
52
53 if !config.enable_gzip {
55 layer = layer.no_gzip();
56 }
57 if !config.enable_brotli {
58 layer = layer.no_br();
59 }
60 if !config.enable_deflate {
61 layer = layer.no_deflate();
62 }
63
64 Self { layer }
65 }
66
67 pub fn level(self, level: CompressionLevel) -> Self {
69 Self {
70 layer: self.layer.quality(level),
71 }
72 }
73
74 pub fn fast(self) -> Self {
76 self.level(CompressionLevel::Fastest)
77 }
78
79 pub fn best(self) -> Self {
81 self.level(CompressionLevel::Best)
82 }
83
84 pub fn no_gzip(self) -> Self {
86 Self {
87 layer: self.layer.no_gzip(),
88 }
89 }
90
91 pub fn no_brotli(self) -> Self {
93 Self {
94 layer: self.layer.no_br(),
95 }
96 }
97
98 pub fn no_deflate(self) -> Self {
100 Self {
101 layer: self.layer.no_deflate(),
102 }
103 }
104
105 pub fn gzip_only(self) -> Self {
107 Self {
108 layer: self.layer.no_br().no_deflate(),
109 }
110 }
111
112 pub fn brotli_only(self) -> Self {
114 Self {
115 layer: self.layer.no_gzip().no_deflate(),
116 }
117 }
118}
119
120impl std::fmt::Debug for CompressionMiddleware {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("CompressionMiddleware")
123 .field("layer", &"<CompressionLayer>")
124 .finish()
125 }
126}
127
128impl Default for CompressionMiddleware {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl Clone for CompressionMiddleware {
135 fn clone(&self) -> Self {
136 Self {
137 layer: self.layer.clone(),
138 }
139 }
140}
141
142impl Middleware for CompressionMiddleware {
143 fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
144 let layer = self.layer.clone();
145
146 Box::pin(async move {
147 let accept_encoding = request
149 .header("accept-encoding")
150 .and_then(|h| h.to_str().ok())
151 .map(|s| s.to_owned())
152 .unwrap_or_default();
153
154 let wants_compression = accept_encoding.contains("gzip")
155 || accept_encoding.contains("br")
156 || accept_encoding.contains("deflate");
157
158 let response = next.run(request).await;
160
161 if !wants_compression {
162 return response;
164 }
165
166 let axum_response = response.into_axum_response();
167 let (parts, body) = axum_response.into_parts();
168
169 let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
171 Ok(bytes) => bytes,
172 Err(_) => {
173 let response =
175 axum::response::Response::from_parts(parts, axum::body::Body::empty());
176 return ElifResponse::from_axum_response(response).await;
177 }
178 };
179
180 let parts_clone = parts.clone();
182 let body_bytes_clone = body_bytes.clone();
183
184 let mock_request = axum::extract::Request::builder()
186 .uri("/")
187 .header("accept-encoding", &accept_encoding)
188 .body(axum::body::Body::empty())
189 .unwrap();
190
191 let service = tower::service_fn(move |_req: axum::extract::Request| {
193 let response_parts = parts.clone();
194 let response_body = body_bytes.clone();
195 async move {
196 let response = axum::response::Response::from_parts(
197 response_parts,
198 axum::body::Body::from(response_body),
199 );
200 Ok::<axum::response::Response, std::convert::Infallible>(response)
201 }
202 });
203
204 let mut compression_service = layer.layer(service);
206
207 match compression_service.call(mock_request).await {
209 Ok(compressed_response) => {
210 let (compressed_parts, compressed_body) = compressed_response.into_parts();
212
213 match compressed_body.collect().await {
215 Ok(collected) => {
216 let compressed_bytes = collected.to_bytes();
218
219 let final_response = axum::response::Response::from_parts(
221 compressed_parts,
222 axum::body::Body::from(compressed_bytes),
223 );
224
225 ElifResponse::from_axum_response(final_response).await
227 }
228 Err(_) => {
229 let original_response = axum::response::Response::from_parts(
231 parts_clone,
232 axum::body::Body::from(body_bytes_clone),
233 );
234 ElifResponse::from_axum_response(original_response).await
235 }
236 }
237 }
238 Err(_) => {
239 let original_response = axum::response::Response::from_parts(
241 parts_clone,
242 axum::body::Body::from(body_bytes_clone),
243 );
244 ElifResponse::from_axum_response(original_response).await
245 }
246 }
247 })
248 }
249
250 fn name(&self) -> &'static str {
251 "CompressionMiddleware"
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::ElifRequest;
    use crate::response::ElifResponse;

    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::default();
        assert!(config.enable_gzip);
        assert!(config.enable_brotli);
        assert!(!config.enable_deflate);
    }

    #[tokio::test]
    async fn test_compression_middleware() {
        let middleware = CompressionMiddleware::new();

        let mut headers = crate::response::headers::ElifHeaderMap::new();
        let encoding_header =
            crate::response::headers::ElifHeaderName::from_str("accept-encoding").unwrap();
        let encoding_value =
            crate::response::headers::ElifHeaderValue::from_str("gzip, br").unwrap();
        headers.insert(encoding_header, encoding_value);
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            headers,
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                let json_data = serde_json::json!({
                    "message": "Hello, World!".repeat(100),
                    "data": (0..100).collect::<Vec<i32>>()
                });
                ElifResponse::ok().json_value(json_data)
            })
        });

        let response = middleware.handle(request, next).await;

        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    /// A request without `Accept-Encoding` must take the pass-through path.
    #[tokio::test]
    async fn test_compression_skipped_without_accept_encoding() {
        let middleware = CompressionMiddleware::new();

        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/data".parse().unwrap(),
            crate::response::headers::ElifHeaderMap::new(),
        );

        let next = Next::new(|_req| {
            Box::pin(async move {
                let json_data = serde_json::json!({ "message": "plain" });
                ElifResponse::ok().json_value(json_data)
            })
        });

        let response = middleware.handle(request, next).await;

        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    // Purely synchronous: no runtime is needed.
    #[test]
    fn test_compression_builder_pattern() {
        let middleware = CompressionMiddleware::new().best().gzip_only();
        assert_eq!(middleware.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_compression_levels() {
        let fast = CompressionMiddleware::new().fast();
        let best = CompressionMiddleware::new().best();
        let custom = CompressionMiddleware::new().level(CompressionLevel::Precise(5));

        assert_eq!(fast.name(), "CompressionMiddleware");
        assert_eq!(best.name(), "CompressionMiddleware");
        assert_eq!(custom.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_algorithm_selection() {
        let gzip_only = CompressionMiddleware::new().gzip_only();
        let brotli_only = CompressionMiddleware::new().brotli_only();
        let no_brotli = CompressionMiddleware::new().no_brotli();

        assert_eq!(gzip_only.name(), "CompressionMiddleware");
        assert_eq!(brotli_only.name(), "CompressionMiddleware");
        assert_eq!(no_brotli.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_clone() {
        let middleware = CompressionMiddleware::new().best();
        let cloned = middleware.clone();

        assert_eq!(cloned.name(), "CompressionMiddleware");
    }

    #[test]
    fn test_debug_hides_layer_internals() {
        let rendered = format!("{:?}", CompressionMiddleware::new());
        assert_eq!(
            rendered,
            "CompressionMiddleware { layer: \"<CompressionLayer>\" }"
        );
    }
}