bws_web_server/middleware/
compression.rs1use crate::config::site::CompressionConfig;
2use brotli::enc::BrotliEncoderParams;
3use bytes::Bytes;
4use flate2::write::{DeflateEncoder, GzEncoder};
5use flate2::Compression;
6use std::io::Write;
7
/// HTTP content coding applied to a response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionMethod {
    /// No content coding; the body is left uncompressed.
    None,
    /// `gzip` content coding.
    Gzip,
    /// `deflate` content coding.
    Deflate,
    /// `br` (Brotli) content coding.
    Brotli,
}

impl CompressionMethod {
    /// Codings this server can produce, most preferred first. The order only
    /// breaks ties between codings the client weights equally.
    const SERVER_PREFERENCE: [CompressionMethod; 3] = [
        CompressionMethod::Brotli,
        CompressionMethod::Gzip,
        CompressionMethod::Deflate,
    ];

    /// Negotiates a content coding from an `Accept-Encoding` header value.
    ///
    /// The header is parsed as a comma-separated list of codings with optional
    /// `;q=` weights (RFC 9110 §12.5.3):
    /// * coding names are compared case-insensitively; `x-gzip` counts as `gzip`;
    /// * a coding with `q=0` (or an unparseable weight) is *not acceptable*;
    /// * `*` applies its weight to every coding not listed explicitly;
    /// * the acceptable coding with the highest weight wins, ties going to the
    ///   server preference Brotli > gzip > deflate.
    ///
    /// Returns [`CompressionMethod::None`] when nothing supported is acceptable.
    pub fn from_accept_encoding(accept_encoding: &str) -> Self {
        let mut best = CompressionMethod::None;
        let mut best_weight = 0.0_f32;
        for &candidate in &Self::SERVER_PREFERENCE {
            let weight = candidate.weight_in(accept_encoding);
            // Strict `>` keeps the earlier (preferred) coding on ties and never
            // selects a coding whose weight is 0 ("not acceptable").
            if weight > best_weight {
                best = candidate;
                best_weight = weight;
            }
        }
        best
    }

    /// Token for the `Content-Encoding` header; empty for [`CompressionMethod::None`].
    pub fn as_str(&self) -> &'static str {
        match self {
            CompressionMethod::None => "",
            CompressionMethod::Gzip => "gzip",
            CompressionMethod::Deflate => "deflate",
            CompressionMethod::Brotli => "br",
        }
    }

    /// Weight in `0.0..=1.0` that `accept_encoding` gives this coding.
    ///
    /// An explicit listing takes precedence over `*`. If the coding is neither
    /// listed nor covered by `*`, the weight is 0.
    fn weight_in(self, accept_encoding: &str) -> f32 {
        let mut wildcard_weight = 0.0_f32;
        for entry in accept_encoding.split(',') {
            let mut parts = entry.split(';');
            let coding = parts.next().unwrap_or("").trim();
            if self.names_coding(coding) {
                return Self::parse_weight(parts);
            }
            if coding == "*" {
                wildcard_weight = Self::parse_weight(parts);
            }
        }
        wildcard_weight
    }

    /// Whether the header token `coding` refers to this method.
    fn names_coding(self, coding: &str) -> bool {
        match self {
            // `identity` is the fallback, never something we negotiate here.
            CompressionMethod::None => false,
            CompressionMethod::Gzip => {
                coding.eq_ignore_ascii_case("gzip") || coding.eq_ignore_ascii_case("x-gzip")
            }
            CompressionMethod::Deflate => coding.eq_ignore_ascii_case("deflate"),
            CompressionMethod::Brotli => coding.eq_ignore_ascii_case("br"),
        }
    }

    /// Reads the `q` parameter from `params`.
    ///
    /// Defaults to 1.0 when absent. A malformed or non-finite value yields 0.0
    /// (treated as not acceptable); anything else is clamped into `0.0..=1.0`.
    fn parse_weight<'a>(params: impl Iterator<Item = &'a str>) -> f32 {
        for param in params {
            if let Some((name, value)) = param.split_once('=') {
                if name.trim().eq_ignore_ascii_case("q") {
                    return value
                        .trim()
                        .parse::<f32>()
                        .ok()
                        .filter(|q| q.is_finite())
                        .map_or(0.0, |q| q.clamp(0.0, 1.0));
                }
            }
        }
        1.0
    }
}
42
43pub struct CompressionMiddleware {
44 config: CompressionConfig,
45}
46
47impl CompressionMiddleware {
48 pub fn new(config: CompressionConfig) -> Self {
49 Self { config }
50 }
51
52 pub fn should_compress(&self, content_type: &str, content_length: usize) -> bool {
54 if !self.config.enabled {
55 return false;
56 }
57
58 if content_length < self.config.min_size {
60 return false;
61 }
62
63 self.config
65 .types
66 .iter()
67 .any(|t| content_type.to_lowercase().starts_with(&t.to_lowercase()))
68 }
69
70 pub fn compress(
72 &self,
73 content: &[u8],
74 method: CompressionMethod,
75 ) -> Result<Bytes, Box<dyn std::error::Error>> {
76 match method {
77 CompressionMethod::None => Ok(Bytes::copy_from_slice(content)),
78 CompressionMethod::Gzip => self.compress_gzip(content),
79 CompressionMethod::Deflate => self.compress_deflate(content),
80 CompressionMethod::Brotli => self.compress_brotli(content),
81 }
82 }
83
84 fn compress_gzip(&self, content: &[u8]) -> Result<Bytes, Box<dyn std::error::Error>> {
85 let mut encoder = GzEncoder::new(Vec::new(), Compression::new(self.config.level));
86 encoder.write_all(content)?;
87 let compressed = encoder.finish()?;
88 Ok(Bytes::from(compressed))
89 }
90
91 fn compress_deflate(&self, content: &[u8]) -> Result<Bytes, Box<dyn std::error::Error>> {
92 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::new(self.config.level));
93 encoder.write_all(content)?;
94 let compressed = encoder.finish()?;
95 Ok(Bytes::from(compressed))
96 }
97
98 fn compress_brotli(&self, content: &[u8]) -> Result<Bytes, Box<dyn std::error::Error>> {
99 let params = BrotliEncoderParams {
100 quality: self.config.level as i32,
101 ..Default::default()
102 };
103
104 let mut compressed = Vec::new();
105 let mut brotli_encoder = brotli::CompressorWriter::with_params(
106 &mut compressed,
107 4096, ¶ms,
109 );
110
111 brotli_encoder.write_all(content)?;
112 brotli_encoder.flush()?;
113 drop(brotli_encoder); Ok(Bytes::from(compressed))
116 }
117
118 pub fn get_best_compression(&self, accept_encoding: Option<&str>) -> CompressionMethod {
120 match accept_encoding {
121 Some(encoding) => CompressionMethod::from_accept_encoding(encoding),
122 None => CompressionMethod::None,
123 }
124 }
125
126 pub fn is_compressible_type(&self, content_type: &str) -> bool {
128 self.config
129 .types
130 .iter()
131 .any(|t| content_type.to_lowercase().starts_with(&t.to_lowercase()))
132 }
133}
134
135pub struct StreamingCompressor {
137 method: CompressionMethod,
138 level: u32,
139}
140
141impl StreamingCompressor {
142 pub fn new(method: CompressionMethod, level: u32) -> Self {
143 Self { method, level }
144 }
145
146 pub fn create_writer<W: Write + 'static>(&self, writer: W) -> Box<dyn Write> {
148 match self.method {
149 CompressionMethod::None => Box::new(writer),
150 CompressionMethod::Gzip => {
151 Box::new(GzEncoder::new(writer, Compression::new(self.level)))
152 }
153 CompressionMethod::Deflate => {
154 Box::new(DeflateEncoder::new(writer, Compression::new(self.level)))
155 }
156 CompressionMethod::Brotli => {
157 let params = BrotliEncoderParams {
158 quality: self.level as i32,
159 ..Default::default()
160 };
161 Box::new(brotli::CompressorWriter::with_params(
162 writer, 4096, ¶ms,
164 ))
165 }
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;

    fn create_test_config() -> CompressionConfig {
        CompressionConfig {
            enabled: true,
            types: vec![
                "text/html".to_string(),
                "text/css".to_string(),
                "application/javascript".to_string(),
                "application/json".to_string(),
            ],
            level: 6,
            min_size: 1024,
        }
    }

    /// Repetitive payload large enough that every codec should shrink it.
    fn sample_payload() -> Vec<u8> {
        b"Hello, World! This is a test string for compression. ".repeat(100)
    }

    /// Drains a decoder into a byte vector, panicking on any decode error.
    fn read_all(mut reader: impl Read) -> Vec<u8> {
        let mut out = Vec::new();
        reader.read_to_end(&mut out).unwrap();
        out
    }

    #[test]
    fn test_compression_method_from_accept_encoding() {
        assert!(matches!(
            CompressionMethod::from_accept_encoding("gzip, deflate, br"),
            CompressionMethod::Brotli
        ));

        assert!(matches!(
            CompressionMethod::from_accept_encoding("gzip, deflate"),
            CompressionMethod::Gzip
        ));

        assert!(matches!(
            CompressionMethod::from_accept_encoding("deflate"),
            CompressionMethod::Deflate
        ));

        assert!(matches!(
            CompressionMethod::from_accept_encoding("identity"),
            CompressionMethod::None
        ));
    }

    #[test]
    fn test_compression_method_as_str() {
        assert_eq!(CompressionMethod::None.as_str(), "");
        assert_eq!(CompressionMethod::Gzip.as_str(), "gzip");
        assert_eq!(CompressionMethod::Deflate.as_str(), "deflate");
        assert_eq!(CompressionMethod::Brotli.as_str(), "br");
    }

    #[test]
    fn test_should_compress() {
        let middleware = CompressionMiddleware::new(create_test_config());

        // Compressible type, above the size threshold.
        assert!(middleware.should_compress("text/html", 2048));
        // The size threshold is inclusive.
        assert!(middleware.should_compress("text/html", 1024));
        // Below `min_size`.
        assert!(!middleware.should_compress("text/html", 512));
        // Type not in the configured list.
        assert!(!middleware.should_compress("image/png", 2048));
        assert!(middleware.should_compress("application/json", 2048));
        // Letter case and media-type parameters do not prevent a match.
        assert!(middleware.should_compress("Text/HTML; charset=utf-8", 2048));
    }

    #[test]
    fn test_gzip_compression() {
        let middleware = CompressionMiddleware::new(create_test_config());
        let test_data = sample_payload();

        let compressed = middleware
            .compress(&test_data, CompressionMethod::Gzip)
            .unwrap();

        assert!(compressed.len() < test_data.len());
        // The output must be a complete, valid gzip stream.
        assert_eq!(read_all(flate2::read::GzDecoder::new(&compressed[..])), test_data);
    }

    #[test]
    fn test_deflate_compression() {
        let middleware = CompressionMiddleware::new(create_test_config());
        let test_data = sample_payload();

        let compressed = middleware
            .compress(&test_data, CompressionMethod::Deflate)
            .unwrap();

        // The exact framing (zlib vs. raw DEFLATE) is deliberately not pinned here.
        assert!(!compressed.is_empty());
        assert!(compressed.len() < test_data.len());
    }

    #[test]
    fn test_brotli_compression() {
        let middleware = CompressionMiddleware::new(create_test_config());
        let test_data = sample_payload();

        let compressed = middleware
            .compress(&test_data, CompressionMethod::Brotli)
            .unwrap();

        assert!(compressed.len() < test_data.len());
        // Decoding to the end proves the stream was finalized, not just flushed.
        assert_eq!(read_all(brotli::Decompressor::new(&compressed[..], 4096)), test_data);
    }

    #[test]
    fn test_no_compression_returns_input() {
        let middleware = CompressionMiddleware::new(create_test_config());
        let test_data = sample_payload();

        let output = middleware
            .compress(&test_data, CompressionMethod::None)
            .unwrap();

        assert_eq!(&output[..], &test_data[..]);
    }

    #[test]
    fn test_compression_disabled() {
        let mut config = create_test_config();
        config.enabled = false;
        let middleware = CompressionMiddleware::new(config);

        assert!(!middleware.should_compress("text/html", 2048));
    }

    #[test]
    fn test_is_compressible_type() {
        let middleware = CompressionMiddleware::new(create_test_config());

        assert!(middleware.is_compressible_type("text/html"));
        assert!(middleware.is_compressible_type("text/html; charset=utf-8"));
        assert!(middleware.is_compressible_type("TEXT/CSS"));
        assert!(middleware.is_compressible_type("application/json"));
        assert!(!middleware.is_compressible_type("image/png"));
        assert!(!middleware.is_compressible_type("video/mp4"));
        assert!(!middleware.is_compressible_type(""));
    }

    #[test]
    fn test_best_compression_selection() {
        let middleware = CompressionMiddleware::new(create_test_config());

        assert!(matches!(
            middleware.get_best_compression(Some("gzip, deflate, br")),
            CompressionMethod::Brotli
        ));

        assert!(matches!(
            middleware.get_best_compression(Some("gzip, deflate")),
            CompressionMethod::Gzip
        ));

        assert!(matches!(
            middleware.get_best_compression(Some("deflate")),
            CompressionMethod::Deflate
        ));

        assert!(matches!(
            middleware.get_best_compression(None),
            CompressionMethod::None
        ));
    }
}