// gatel_core/hoops/compress.rs

use std::io::Cursor;

use async_compression::Level;
use async_compression::tokio::bufread::{BrotliEncoder, DeflateEncoder, GzipEncoder, ZstdEncoder};
use http::HeaderValue;
use http::header::{
    ACCEPT_ENCODING, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE,
    ETAG, VARY,
};
use http_body_util::BodyExt;
use salvo::http::ResBody;
use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
use tokio::io::AsyncReadExt;
use tracing::debug;

/// Size threshold, in bytes: response bodies shorter than this are sent
/// uncompressed because they are too small to be worth the CPU.
const MIN_COMPRESS_SIZE: usize = 256;

/// Content codings this middleware can produce.
///
/// The declaration order (zstd, br, gzip, deflate) is the same order used as
/// the fallback when no valid encoding is configured. The preference applied
/// at runtime is the order of the configured list of enabled encodings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Encoding {
    Zstd,
    Brotli,
    Gzip,
    Deflate,
}

impl Encoding {
    /// The HTTP content-coding token for this encoding, as used in
    /// `Content-Encoding` and matched against `Accept-Encoding`.
    fn as_str(self) -> &'static str {
        match self {
            Self::Zstd => "zstd",
            Self::Brotli => "br",
            Self::Gzip => "gzip",
            Self::Deflate => "deflate",
        }
    }
}

35/// Response compression middleware.
36///
37/// Inspects the `Accept-Encoding` request header, selects the best supported
38/// encoding, and compresses the response body if it is a compressible content
39/// type and large enough to be worth compressing.
40pub struct CompressHoop {
41    /// Enabled encodings, in preference order.
42    enabled: Vec<Encoding>,
43    /// Optional compression level. When set, uses `Level::Precise(level)`.
44    level: Option<u32>,
45}
46
47impl CompressHoop {
48    /// Create from a list of encoding names (e.g. `["gzip", "zstd", "br"]`) and an optional level.
49    pub fn new(encodings: &[String], level: Option<u32>) -> Self {
50        let mut enabled = Vec::new();
51        for name in encodings {
52            match name.as_str() {
53                "gzip" => enabled.push(Encoding::Gzip),
54                "zstd" => enabled.push(Encoding::Zstd),
55                "br" | "brotli" => enabled.push(Encoding::Brotli),
56                "deflate" => enabled.push(Encoding::Deflate),
57                other => {
58                    debug!(encoding = other, "unknown encoding requested, skipping");
59                }
60            }
61        }
62        // If nothing valid was provided, default to all four.
63        if enabled.is_empty() {
64            enabled = vec![
65                Encoding::Zstd,
66                Encoding::Brotli,
67                Encoding::Gzip,
68                Encoding::Deflate,
69            ];
70        }
71        Self { enabled, level }
72    }
73}
74
75#[async_trait]
76impl salvo::Handler for CompressHoop {
77    async fn handle(
78        &self,
79        req: &mut Request,
80        depot: &mut Depot,
81        res: &mut Response,
82        ctrl: &mut FlowCtrl,
83    ) {
84        // Determine which encoding the client accepts that we also support.
85        let chosen = choose_encoding(req.headers(), &self.enabled);
86
87        ctrl.call_next(req, depot, res).await;
88
89        // If no acceptable encoding, return uncompressed.
90        let encoding = match chosen {
91            Some(e) => e,
92            None => return,
93        };
94
95        // Skip if response already has Content-Encoding.
96        if res.headers().contains_key(CONTENT_ENCODING) {
97            return;
98        }
99
100        // Only compress compressible content types.
101        if !is_compressible_content_type(res.headers()) {
102            return;
103        }
104
105        // Take the body and collect it.
106        let body = res.take_body();
107        let body_bytes = match collect_res_body(body).await {
108            Ok(bytes) => bytes,
109            Err(_) => return,
110        };
111
112        // Skip if body is too small.
113        if body_bytes.len() < MIN_COMPRESS_SIZE {
114            res.body(body_bytes);
115            return;
116        }
117
118        // Compress.
119        let compressed = match compress_bytes(&body_bytes, encoding, self.level).await {
120            Ok(c) => c,
121            Err(_) => {
122                res.body(body_bytes);
123                return;
124            }
125        };
126
127        debug!(
128            encoding = encoding.as_str(),
129            original = body_bytes.len(),
130            compressed = compressed.len(),
131            "compressed response"
132        );
133
134        // Update headers.
135        res.headers_mut()
136            .insert(CONTENT_ENCODING, encoding.as_str().parse().unwrap());
137        res.headers_mut().remove(CONTENT_LENGTH);
138        res.headers_mut()
139            .insert(CONTENT_LENGTH, compressed.len().into());
140
141        res.body(compressed);
142    }
143}
144
145/// Collect a Salvo ResBody into bytes (public for reuse by other middleware).
146pub async fn collect_res_body_bytes(body: ResBody) -> Result<Vec<u8>, ()> {
147    collect_res_body(body).await
148}
149
150/// Collect a Salvo ResBody into bytes.
151async fn collect_res_body(body: ResBody) -> Result<Vec<u8>, ()> {
152    match body {
153        ResBody::None => Ok(Vec::new()),
154        ResBody::Once(bytes) => Ok(bytes.to_vec()),
155        ResBody::Boxed(boxed) => {
156            let collected = boxed.collect().await.map_err(|_| ())?;
157            Ok(collected.to_bytes().to_vec())
158        }
159        other => {
160            // For Hyper and other body types, try to collect through the Body trait.
161            use http_body::Body;
162            let mut buf = Vec::new();
163            let mut pinned = Box::pin(other);
164            loop {
165                match std::future::poll_fn(|cx| pinned.as_mut().poll_frame(cx)).await {
166                    Some(Ok(frame)) => {
167                        if let Ok(data) = frame.into_data() {
168                            buf.extend_from_slice(&data);
169                        }
170                    }
171                    Some(Err(_)) => return Err(()),
172                    None => break,
173                }
174            }
175            Ok(buf)
176        }
177    }
178}
179
180/// Choose the best encoding from the client's Accept-Encoding that we support.
181fn choose_encoding(headers: &http::HeaderMap, enabled: &[Encoding]) -> Option<Encoding> {
182    let accept = headers.get(ACCEPT_ENCODING)?.to_str().ok()?;
183    // Parse quality values (simplified: we just check presence, not q values).
184    // Order by our preference (the order in `enabled`).
185    for enc in enabled {
186        let token = enc.as_str();
187        if accept.contains(token) || accept.contains("*") {
188            return Some(*enc);
189        }
190    }
191    None
192}
193
194/// Returns true if the response Content-Type is a compressible text-like MIME.
195fn is_compressible_content_type(headers: &http::HeaderMap) -> bool {
196    let ct = match headers.get(CONTENT_TYPE) {
197        Some(v) => match v.to_str() {
198            Ok(s) => s.to_ascii_lowercase(),
199            Err(_) => return false,
200        },
201        // No Content-Type — don't compress.
202        None => return false,
203    };
204
205    // text/* is always compressible.
206    if ct.starts_with("text/") {
207        return true;
208    }
209
210    const COMPRESSIBLE: &[&str] = &[
211        "application/json",
212        "application/javascript",
213        "application/xml",
214        "application/xhtml+xml",
215        "application/rss+xml",
216        "application/atom+xml",
217        "application/wasm",
218        "application/manifest+json",
219        "application/ld+json",
220        "application/graphql+json",
221        "application/geo+json",
222        "application/vnd.api+json",
223        "image/svg+xml",
224    ];
225
226    for mime in COMPRESSIBLE {
227        if ct.starts_with(mime) {
228            return true;
229        }
230    }
231
232    false
233}
234
235/// Compress bytes with the given encoding and optional level.
236async fn compress_bytes(
237    data: &[u8],
238    encoding: Encoding,
239    level: Option<u32>,
240) -> Result<Vec<u8>, crate::ProxyError> {
241    let cursor = Cursor::new(data);
242    let reader = tokio::io::BufReader::new(cursor);
243    let mut output = Vec::new();
244    let compress_level = level.map(|l| Level::Precise(l as i32)).unwrap_or_default();
245
246    match encoding {
247        Encoding::Gzip => {
248            let mut encoder = GzipEncoder::with_quality(reader, compress_level);
249            encoder
250                .read_to_end(&mut output)
251                .await
252                .map_err(|e| crate::ProxyError::Internal(format!("gzip compression error: {e}")))?;
253        }
254        Encoding::Zstd => {
255            let mut encoder = ZstdEncoder::with_quality(reader, compress_level);
256            encoder
257                .read_to_end(&mut output)
258                .await
259                .map_err(|e| crate::ProxyError::Internal(format!("zstd compression error: {e}")))?;
260        }
261        Encoding::Brotli => {
262            let mut encoder = BrotliEncoder::with_quality(reader, compress_level);
263            encoder.read_to_end(&mut output).await.map_err(|e| {
264                crate::ProxyError::Internal(format!("brotli compression error: {e}"))
265            })?;
266        }
267        Encoding::Deflate => {
268            let mut encoder = DeflateEncoder::with_quality(reader, compress_level);
269            encoder.read_to_end(&mut output).await.map_err(|e| {
270                crate::ProxyError::Internal(format!("deflate compression error: {e}"))
271            })?;
272        }
273    }
274
275    Ok(output)
276}