warpdrive_proxy/middleware/
compression.rs

1//! Response compression middleware
2//!
3//! Implements Brotli and gzip compression for responses based on client
4//! Accept-Encoding headers.
5//!
6//! # Compression Strategy
7//!
//! - Only compress if the client accepts Brotli or gzip (Accept-Encoding: br / gzip)
9//! - Skip compression for already-compressed content (Content-Encoding present)
10//! - Skip compression for small responses (< 1KB)
11//! - Skip compression for non-compressible types (images, video, etc.)
12//!
13//! # Note on Pingora Integration
14//!
15//! Pingora doesn't have built-in response body modification in filters, so this
16//! middleware sets headers to signal compression intent. Actual compression happens
17//! in the response body stream using async-compression.
18
19use async_trait::async_trait;
20use pingora::http::ResponseHeader;
21use pingora::prelude::*;
22use tracing::debug;
23
24use super::{CompressionEncoding, Middleware, MiddlewareContext};
25
/// Response compression middleware.
///
/// Decides, per response, whether the body is eligible for compression for the
/// requesting client and prepares the response headers accordingly.
#[derive(Debug, Clone)]
pub struct CompressionMiddleware {
    /// Minimum response size, in bytes, that is considered for compression.
    ///
    /// Responses whose `Content-Length` is below this threshold are left as-is.
    min_compress_size: usize,
}
33
34impl CompressionMiddleware {
35    /// Create new compression middleware
36    pub fn new() -> Self {
37        Self {
38            min_compress_size: 1024, // 1KB minimum
39        }
40    }
41
42    /// Check if content type should be compressed
43    pub(crate) fn should_compress_content_type(content_type: &str) -> bool {
44        // Compress text-based content
45        let compressible_prefixes = [
46            "text/",
47            "application/json",
48            "application/javascript",
49            "application/xml",
50            "application/x-javascript",
51            "application/xhtml+xml",
52            "application/rss+xml",
53            "application/atom+xml",
54        ];
55
56        compressible_prefixes
57            .iter()
58            .any(|prefix| content_type.starts_with(prefix))
59    }
60
61    /// Determine the best compression encoding supported by the client
62    ///
63    /// Returns the preferred encoding in order: Brotli > Gzip > None
64    /// Brotli provides better compression ratios (~15-25% smaller than gzip)
65    fn best_supported_encoding(session: &Session) -> Option<CompressionEncoding> {
66        if let Some(accept_encoding) = session.req_header().headers.get("accept-encoding") {
67            if let Ok(value) = accept_encoding.to_str() {
68                let encodings: Vec<&str> = value.split(',').map(|s| s.trim()).collect();
69
70                // Prefer Brotli (better compression ratio)
71                if encodings.contains(&"br") {
72                    return Some(CompressionEncoding::Brotli);
73                }
74
75                // Fallback to Gzip (wider support)
76                if encodings.contains(&"gzip") {
77                    return Some(CompressionEncoding::Gzip);
78                }
79            }
80        }
81        None
82    }
83}
84
85impl Default for CompressionMiddleware {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91#[async_trait]
92impl Middleware for CompressionMiddleware {
93    /// Process response to determine if compression should be applied
94    ///
95    /// This sets a custom header (X-WarpDrive-Compress) to signal that compression
96    /// should be applied to the response body downstream.
97    async fn response_filter(
98        &self,
99        session: &mut Session,
100        upstream_response: &mut ResponseHeader,
101        ctx: &mut MiddlewareContext,
102    ) -> Result<()> {
103        // Detect streaming responses
104        // Check for Content-Type: text/event-stream (Server-Sent Events)
105        if let Some(content_type) = upstream_response.headers.get("content-type") {
106            if let Ok(ct_str) = content_type.to_str() {
107                if ct_str.contains("text/event-stream") {
108                    debug!("Streaming detected: SSE (text/event-stream)");
109                    ctx.streaming = true;
110                    return Ok(());
111                }
112            }
113        }
114
115        // Check for X-Accel-Buffering: no (nginx compatibility)
116        if let Some(buffering) = upstream_response.headers.get("x-accel-buffering") {
117            if let Ok(value) = buffering.to_str() {
118                if value.to_lowercase() == "no" {
119                    debug!("Streaming detected: X-Accel-Buffering: no");
120                    ctx.streaming = true;
121                    return Ok(());
122                }
123            }
124        }
125
126        // Determine best compression encoding supported by client
127        let encoding = match Self::best_supported_encoding(session) {
128            Some(enc) => enc,
129            None => {
130                debug!("Client doesn't support compression");
131                return Ok(());
132            }
133        };
134
135        // Skip if sendfile already took over body streaming
136        if ctx.sendfile.active {
137            debug!("Skipping compression: sendfile response");
138            return Ok(());
139        }
140
141        // Skip if already compressed
142        if upstream_response.headers.contains_key("content-encoding") {
143            debug!("Skipping compression: already encoded");
144            return Ok(());
145        }
146
147        // Check content type
148        if let Some(content_type) = upstream_response.headers.get("content-type") {
149            if let Ok(ct_str) = content_type.to_str() {
150                if !Self::should_compress_content_type(ct_str) {
151                    debug!(
152                        "Skipping compression: non-compressible content type {}",
153                        ct_str
154                    );
155                    return Ok(());
156                }
157            }
158        }
159
160        // Check content length
161        if let Some(content_length) = upstream_response.headers.get("content-length") {
162            if let Ok(length_str) = content_length.to_str() {
163                if let Ok(length) = length_str.parse::<usize>() {
164                    if length < self.min_compress_size {
165                        debug!(
166                            "Skipping compression: response too small ({} bytes)",
167                            length
168                        );
169                        return Ok(());
170                    }
171                }
172            }
173        }
174
175        // At this point, compression is eligible
176        debug!("Response eligible for {:?} compression", encoding);
177
178        // Set Vary header to indicate response varies by Accept-Encoding
179        upstream_response.insert_header("Vary", "Accept-Encoding")?;
180
181        // Set Content-Encoding and remove Content-Length (size changes post-compression)
182        let encoding_str = match encoding {
183            CompressionEncoding::Brotli => "br",
184            CompressionEncoding::Gzip => "gzip",
185        };
186        upstream_response.insert_header("Content-Encoding", encoding_str)?;
187        upstream_response.remove_header("Content-Length");
188
189        // Initialize compression state in context for body filter
190        if !ctx.compression.is_enabled() {
191            ctx.compression.enable(encoding);
192        }
193
194        Ok(())
195    }
196}
197
#[cfg(test)]
mod tests {
    // `super::*` brings `CompressionMiddleware`, `CompressionEncoding`,
    // `MiddlewareContext` and `ResponseHeader` into scope for every test.
    use super::*;

    /// Text-based types are compressible; binary media types are not.
    #[test]
    fn test_should_compress_content_type() {
        assert!(CompressionMiddleware::should_compress_content_type(
            "text/html"
        ));
        assert!(CompressionMiddleware::should_compress_content_type(
            "text/plain"
        ));
        assert!(CompressionMiddleware::should_compress_content_type(
            "application/json"
        ));
        assert!(CompressionMiddleware::should_compress_content_type(
            "application/javascript"
        ));
        assert!(CompressionMiddleware::should_compress_content_type(
            "application/xml"
        ));
        // Parameters after the media type must not affect the decision.
        assert!(CompressionMiddleware::should_compress_content_type(
            "text/html; charset=utf-8"
        ));

        assert!(!CompressionMiddleware::should_compress_content_type(
            "image/png"
        ));
        assert!(!CompressionMiddleware::should_compress_content_type(
            "image/jpeg"
        ));
        assert!(!CompressionMiddleware::should_compress_content_type(
            "video/mp4"
        ));
        assert!(!CompressionMiddleware::should_compress_content_type(
            "application/pdf"
        ));
    }

    /// `new()` and `default()` both use the 1 KB threshold.
    #[test]
    fn test_compression_middleware_creation() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(middleware.min_compress_size, 1024);

        let middleware_default = CompressionMiddleware::default();
        assert_eq!(middleware_default.min_compress_size, 1024);
    }

    /// A fresh context is not streaming, and an SSE content type is readable
    /// from the response in the form the streaming check looks for.
    #[test]
    fn test_sse_streaming_detection() {
        let ctx = MiddlewareContext::default();
        assert!(!ctx.streaming, "Context should start with streaming=false");

        // Create response with SSE content type
        let mut response = ResponseHeader::build(200, None).unwrap();
        response
            .insert_header("content-type", "text/event-stream")
            .unwrap();

        // Fail loudly if the header is missing instead of silently skipping the assertion.
        let content_type = response
            .headers
            .get("content-type")
            .expect("content-type header was just inserted")
            .to_str()
            .expect("content-type value should be valid ASCII");
        assert!(content_type.contains("text/event-stream"));
    }

    /// A fresh context is not streaming, and `X-Accel-Buffering: no` is readable
    /// from the response in the form the streaming check looks for.
    #[test]
    fn test_x_accel_buffering_detection() {
        let ctx = MiddlewareContext::default();
        assert!(!ctx.streaming);

        // Test X-Accel-Buffering: no
        let mut response = ResponseHeader::build(200, None).unwrap();
        response.insert_header("x-accel-buffering", "no").unwrap();

        // Fail loudly if the header is missing instead of silently skipping the assertion.
        let buffering = response
            .headers
            .get("x-accel-buffering")
            .expect("x-accel-buffering header was just inserted")
            .to_str()
            .expect("x-accel-buffering value should be valid ASCII");
        assert_eq!(buffering.to_lowercase(), "no");
    }

    /// Sanity check on the encoding variants used for negotiation.
    #[test]
    fn test_brotli_preferred_over_gzip() {
        // The Brotli-first preference lives in `best_supported_encoding()`, which
        // needs a `Session` and is not exercised here. This only verifies that
        // both variants exist and are distinct.
        let br = CompressionEncoding::Brotli;
        let gz = CompressionEncoding::Gzip;

        assert_ne!(br, gz);
    }

    /// Each variant equals itself and differs from the other.
    #[test]
    fn test_compression_encoding_equality() {
        assert_eq!(CompressionEncoding::Brotli, CompressionEncoding::Brotli);
        assert_eq!(CompressionEncoding::Gzip, CompressionEncoding::Gzip);
        assert_ne!(CompressionEncoding::Brotli, CompressionEncoding::Gzip);
    }
}