// bws_web_server/middleware/compression.rs

1use crate::config::site::CompressionConfig;
2use brotli::enc::BrotliEncoderParams;
3use bytes::Bytes;
4use flate2::write::{DeflateEncoder, GzEncoder};
5use flate2::Compression;
6use std::io::Write;
7
/// HTTP content codings this server can apply to a response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionMethod {
    /// No content coding; the body is sent unmodified.
    None,
    /// The `gzip` content coding.
    Gzip,
    /// The `deflate` content coding.
    Deflate,
    /// The `br` (Brotli) content coding.
    Brotli,
}

impl CompressionMethod {
    /// Choose the best supported coding allowed by an `Accept-Encoding` header value.
    ///
    /// The value is parsed as a comma-separated list of `coding[;q=weight]` elements,
    /// case-insensitively (RFC 9110 §12.5.3):
    ///
    /// * a coding listed with `q=0` is explicitly refused and never chosen;
    /// * a `*` element applies its weight to every coding not listed by name;
    /// * `x-gzip` is accepted as an alias for `gzip`.
    ///
    /// Among the acceptable codings the server preference brotli > gzip > deflate
    /// decides; non-zero client weights do not reorder it. Returns
    /// [`CompressionMethod::None`] when no supported coding is acceptable.
    pub fn from_accept_encoding(accept_encoding: &str) -> Self {
        let entries: Vec<(String, f32)> = accept_encoding
            .split(',')
            .filter_map(Self::parse_coding)
            .collect();

        // Server-side preference order: the first acceptable coding wins.
        if Self::is_accepted(&entries, &["br"]) {
            CompressionMethod::Brotli
        } else if Self::is_accepted(&entries, &["gzip", "x-gzip"]) {
            CompressionMethod::Gzip
        } else if Self::is_accepted(&entries, &["deflate"]) {
            CompressionMethod::Deflate
        } else {
            CompressionMethod::None
        }
    }

    /// Token to send in the `Content-Encoding` header; empty for [`CompressionMethod::None`].
    pub fn as_str(&self) -> &'static str {
        match self {
            CompressionMethod::None => "",
            CompressionMethod::Gzip => "gzip",
            CompressionMethod::Deflate => "deflate",
            CompressionMethod::Brotli => "br",
        }
    }

    /// Split one list element into its lowercased coding name and `q` weight.
    ///
    /// Empty elements (e.g. from a trailing comma) yield `None`. A missing or
    /// unparsable `q` parameter counts as `1.0`, so a garbled weight never disables
    /// a coding the client listed.
    fn parse_coding(element: &str) -> Option<(String, f32)> {
        let mut parts = element.split(';');
        let coding = parts.next()?.trim().to_ascii_lowercase();
        if coding.is_empty() {
            return None;
        }

        let mut weight = 1.0;
        for param in parts {
            let mut pair = param.splitn(2, '=');
            let name = pair.next().unwrap_or("").trim();
            if name.eq_ignore_ascii_case("q") {
                if let Some(q) = pair.next().and_then(|v| v.trim().parse::<f32>().ok()) {
                    weight = q;
                }
            }
        }
        Some((coding, weight))
    }

    /// Whether the client accepts the coding known by any of `aliases`.
    ///
    /// An explicit entry for the coding decides; failing that, a `*` entry decides;
    /// a coding the header does not mention at all is not accepted.
    fn is_accepted(entries: &[(String, f32)], aliases: &[&str]) -> bool {
        let explicit = entries
            .iter()
            .find(|(coding, _)| aliases.contains(&coding.as_str()))
            .map(|(_, weight)| *weight);
        let wildcard = || {
            entries
                .iter()
                .find(|(coding, _)| coding == "*")
                .map(|(_, weight)| *weight)
        };
        explicit.or_else(wildcard).map_or(false, |weight| weight > 0.0)
    }
}
42
43pub struct CompressionMiddleware {
44    config: CompressionConfig,
45}
46
47impl CompressionMiddleware {
48    pub fn new(config: CompressionConfig) -> Self {
49        Self { config }
50    }
51
52    /// Check if content should be compressed based on content type and size
53    pub fn should_compress(&self, content_type: &str, content_length: usize) -> bool {
54        if !self.config.enabled {
55            return false;
56        }
57
58        // Check minimum size
59        if content_length < self.config.min_size {
60            return false;
61        }
62
63        // Check if content type is in the list of compressible types
64        self.config
65            .types
66            .iter()
67            .any(|t| content_type.to_lowercase().starts_with(&t.to_lowercase()))
68    }
69
70    /// Compress content using the specified method
71    pub fn compress(
72        &self,
73        content: &[u8],
74        method: CompressionMethod,
75    ) -> Result<Bytes, Box<dyn std::error::Error>> {
76        match method {
77            CompressionMethod::None => Ok(Bytes::copy_from_slice(content)),
78            CompressionMethod::Gzip => self.compress_gzip(content),
79            CompressionMethod::Deflate => self.compress_deflate(content),
80            CompressionMethod::Brotli => self.compress_brotli(content),
81        }
82    }
83
84    fn compress_gzip(&self, content: &[u8]) -> Result<Bytes, Box<dyn std::error::Error>> {
85        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(self.config.level));
86        encoder.write_all(content)?;
87        let compressed = encoder.finish()?;
88        Ok(Bytes::from(compressed))
89    }
90
91    fn compress_deflate(&self, content: &[u8]) -> Result<Bytes, Box<dyn std::error::Error>> {
92        let mut encoder = DeflateEncoder::new(Vec::new(), Compression::new(self.config.level));
93        encoder.write_all(content)?;
94        let compressed = encoder.finish()?;
95        Ok(Bytes::from(compressed))
96    }
97
98    fn compress_brotli(&self, content: &[u8]) -> Result<Bytes, Box<dyn std::error::Error>> {
99        let params = BrotliEncoderParams {
100            quality: self.config.level as i32,
101            ..Default::default()
102        };
103
104        let mut compressed = Vec::new();
105        let mut brotli_encoder = brotli::CompressorWriter::with_params(
106            &mut compressed,
107            4096, // buffer size
108            &params,
109        );
110
111        brotli_encoder.write_all(content)?;
112        brotli_encoder.flush()?;
113        drop(brotli_encoder); // Ensure compression is finalized
114
115        Ok(Bytes::from(compressed))
116    }
117
118    /// Get the best compression method based on Accept-Encoding header
119    pub fn get_best_compression(&self, accept_encoding: Option<&str>) -> CompressionMethod {
120        match accept_encoding {
121            Some(encoding) => CompressionMethod::from_accept_encoding(encoding),
122            None => CompressionMethod::None,
123        }
124    }
125
126    /// Check if a content type is compressible
127    pub fn is_compressible_type(&self, content_type: &str) -> bool {
128        self.config
129            .types
130            .iter()
131            .any(|t| content_type.to_lowercase().starts_with(&t.to_lowercase()))
132    }
133}
134
135/// Streaming compression wrapper for large files
136pub struct StreamingCompressor {
137    method: CompressionMethod,
138    level: u32,
139}
140
141impl StreamingCompressor {
142    pub fn new(method: CompressionMethod, level: u32) -> Self {
143        Self { method, level }
144    }
145
146    /// Create a compressor writer for streaming compression
147    pub fn create_writer<W: Write + 'static>(&self, writer: W) -> Box<dyn Write> {
148        match self.method {
149            CompressionMethod::None => Box::new(writer),
150            CompressionMethod::Gzip => {
151                Box::new(GzEncoder::new(writer, Compression::new(self.level)))
152            }
153            CompressionMethod::Deflate => {
154                Box::new(DeflateEncoder::new(writer, Compression::new(self.level)))
155            }
156            CompressionMethod::Brotli => {
157                let params = BrotliEncoderParams {
158                    quality: self.level as i32,
159                    ..Default::default()
160                };
161                Box::new(brotli::CompressorWriter::with_params(
162                    writer, 4096, // buffer size
163                    &params,
164                ))
165            }
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::io::{Read, Write};
    use std::rc::Rc;

    fn create_test_config() -> CompressionConfig {
        CompressionConfig {
            enabled: true,
            types: vec![
                "text/html".to_string(),
                "text/css".to_string(),
                "application/javascript".to_string(),
                "application/json".to_string(),
            ],
            level: 6,
            min_size: 1024,
        }
    }

    /// Repetitive payload that every codec shrinks well below its original size.
    fn sample_data() -> Vec<u8> {
        b"Hello, World! This is a test string for compression. ".repeat(100)
    }

    /// Decode a gzip stream.
    fn gunzip(data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        flate2::read::GzDecoder::new(data)
            .read_to_end(&mut out)
            .expect("valid gzip stream");
        out
    }

    /// Decode the raw DEFLATE stream produced by `flate2::write::DeflateEncoder`.
    fn inflate_raw(data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        flate2::read::DeflateDecoder::new(data)
            .read_to_end(&mut out)
            .expect("valid deflate stream");
        out
    }

    /// Decode a brotli stream.
    fn unbrotli(data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        brotli::Decompressor::new(data, 4096)
            .read_to_end(&mut out)
            .expect("valid brotli stream");
        out
    }

    /// In-memory sink whose contents stay readable after the boxed writer owning a
    /// clone of it is dropped (`create_writer` requires a `'static` writer).
    #[derive(Clone, Default)]
    struct SharedBuf(Rc<RefCell<Vec<u8>>>);

    impl Write for SharedBuf {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.borrow_mut().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Push `data` through a `StreamingCompressor` and return what reached the sink.
    fn compress_streaming(method: CompressionMethod, data: &[u8]) -> Vec<u8> {
        let sink = SharedBuf::default();
        let mut writer = StreamingCompressor::new(method, 6).create_writer(sink.clone());
        writer.write_all(data).unwrap();
        // Dropping the boxed encoder finishes the compressed stream.
        drop(writer);
        let compressed = sink.0.borrow().clone();
        compressed
    }

    #[test]
    fn test_compression_method_from_accept_encoding() {
        // Test brotli preference
        assert!(matches!(
            CompressionMethod::from_accept_encoding("gzip, deflate, br"),
            CompressionMethod::Brotli
        ));

        // Test gzip fallback
        assert!(matches!(
            CompressionMethod::from_accept_encoding("gzip, deflate"),
            CompressionMethod::Gzip
        ));

        // Test deflate fallback
        assert!(matches!(
            CompressionMethod::from_accept_encoding("deflate"),
            CompressionMethod::Deflate
        ));

        // Test no compression
        assert!(matches!(
            CompressionMethod::from_accept_encoding("identity"),
            CompressionMethod::None
        ));
    }

    #[test]
    fn test_should_compress() {
        let middleware = CompressionMiddleware::new(create_test_config());

        // Should compress HTML content above min size
        assert!(middleware.should_compress("text/html", 2048));

        // Should not compress below min size
        assert!(!middleware.should_compress("text/html", 512));

        // Should not compress non-text content
        assert!(!middleware.should_compress("image/png", 2048));

        // Should compress JSON content
        assert!(middleware.should_compress("application/json", 2048));
    }

    #[test]
    fn test_gzip_compression() {
        let middleware = CompressionMiddleware::new(create_test_config());
        let test_data = sample_data();

        let compressed = middleware
            .compress(&test_data, CompressionMethod::Gzip)
            .unwrap();

        assert_ne!(compressed.as_ref(), test_data.as_slice());
        assert!(compressed.len() < test_data.len());
        // The output must also be a valid gzip stream that restores the input.
        assert_eq!(gunzip(&compressed), test_data);
    }

    #[test]
    fn test_deflate_compression() {
        let middleware = CompressionMiddleware::new(create_test_config());
        let test_data = sample_data();

        let compressed = middleware
            .compress(&test_data, CompressionMethod::Deflate)
            .unwrap();

        assert!(compressed.len() < test_data.len());
        assert_eq!(inflate_raw(&compressed), test_data);
    }

    #[test]
    fn test_brotli_compression() {
        let middleware = CompressionMiddleware::new(create_test_config());
        let test_data = sample_data();

        let compressed = middleware
            .compress(&test_data, CompressionMethod::Brotli)
            .unwrap();

        assert_ne!(compressed.as_ref(), test_data.as_slice());
        assert!(compressed.len() < test_data.len());
        // Round-trip proves the brotli stream was finalized, not just started.
        assert_eq!(unbrotli(&compressed), test_data);
    }

    #[test]
    fn test_no_compression_passthrough() {
        let middleware = CompressionMiddleware::new(create_test_config());
        let test_data = sample_data();

        let output = middleware
            .compress(&test_data, CompressionMethod::None)
            .unwrap();

        assert_eq!(output.as_ref(), test_data.as_slice());
    }

    #[test]
    fn test_compression_disabled() {
        let mut config = create_test_config();
        config.enabled = false;
        let middleware = CompressionMiddleware::new(config);

        // Should not compress when disabled
        assert!(!middleware.should_compress("text/html", 2048));
    }

    #[test]
    fn test_is_compressible_type() {
        let middleware = CompressionMiddleware::new(create_test_config());

        assert!(middleware.is_compressible_type("text/html"));
        assert!(middleware.is_compressible_type("text/html; charset=utf-8"));
        assert!(middleware.is_compressible_type("TEXT/HTML"));
        assert!(middleware.is_compressible_type("application/json"));
        assert!(!middleware.is_compressible_type("image/png"));
        assert!(!middleware.is_compressible_type("video/mp4"));
    }

    #[test]
    fn test_best_compression_selection() {
        let middleware = CompressionMiddleware::new(create_test_config());

        // Test various Accept-Encoding headers
        assert!(matches!(
            middleware.get_best_compression(Some("gzip, deflate, br")),
            CompressionMethod::Brotli
        ));

        assert!(matches!(
            middleware.get_best_compression(Some("gzip, deflate")),
            CompressionMethod::Gzip
        ));

        assert!(matches!(
            middleware.get_best_compression(Some("deflate")),
            CompressionMethod::Deflate
        ));

        assert!(matches!(
            middleware.get_best_compression(None),
            CompressionMethod::None
        ));
    }

    #[test]
    fn test_streaming_compression_round_trip() {
        let data = sample_data();

        assert_eq!(gunzip(&compress_streaming(CompressionMethod::Gzip, &data)), data);
        assert_eq!(
            inflate_raw(&compress_streaming(CompressionMethod::Deflate, &data)),
            data
        );
        assert_eq!(
            unbrotli(&compress_streaming(CompressionMethod::Brotli, &data)),
            data
        );
        assert_eq!(compress_streaming(CompressionMethod::None, &data), data);
    }
}