Skip to main content

gatel_core/hoops/
decompress.rs

1use std::io::Cursor;
2
3use async_compression::tokio::bufread::{BrotliDecoder, DeflateDecoder, GzipDecoder, ZstdDecoder};
4use http::header::{CONTENT_ENCODING, CONTENT_LENGTH};
5use salvo::http::ReqBody;
6use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
7use tokio::io::AsyncReadExt;
8use tracing::debug;
9
/// Request body decompression middleware.
///
/// Inspects the `Content-Encoding` request header and decompresses the body
/// before forwarding to the next handler. Supports gzip, brotli, zstd, and
/// deflate. After decompression, the `Content-Encoding` header is removed and
/// `Content-Length` is updated. Requests without a recognised coding are
/// passed through unchanged; bodies that fail to decompress (including ones
/// that exceed the size cap) are rejected with `400 Bad Request`.
#[derive(Debug, Clone)]
pub struct DecompressHoop {
    /// Maximum decompressed body size (bytes) to prevent decompression bombs.
    max_size: usize,
}

impl DecompressHoop {
    /// Decompressed-size cap used when none is configured: 64 MiB.
    pub const DEFAULT_MAX_SIZE: usize = 64 * 1024 * 1024;

    /// Create a new decompression middleware.
    ///
    /// `max_size` limits the decompressed output, in bytes, to guard against
    /// decompression bombs. `None` selects [`Self::DEFAULT_MAX_SIZE`] (64 MiB).
    #[must_use]
    pub fn new(max_size: Option<usize>) -> Self {
        Self {
            max_size: max_size.unwrap_or(Self::DEFAULT_MAX_SIZE),
        }
    }
}

impl Default for DecompressHoop {
    /// Equivalent to `DecompressHoop::new(None)`.
    fn default() -> Self {
        Self::new(None)
    }
}
32
33#[async_trait]
34impl salvo::Handler for DecompressHoop {
35    async fn handle(
36        &self,
37        req: &mut Request,
38        depot: &mut Depot,
39        res: &mut Response,
40        ctrl: &mut FlowCtrl,
41    ) {
42        let encoding = req
43            .headers()
44            .get(CONTENT_ENCODING)
45            .and_then(|v| v.to_str().ok())
46            .map(|s| s.to_ascii_lowercase());
47
48        let encoding = match encoding {
49            Some(e) if e == "gzip" || e == "br" || e == "zstd" || e == "deflate" => e,
50            _ => {
51                // No compression or unsupported — pass through.
52                ctrl.call_next(req, depot, res).await;
53                return;
54            }
55        };
56
57        // Read the compressed body.
58        let compressed = match req.payload().await {
59            Ok(bytes) => bytes.to_vec(),
60            Err(_) => {
61                ctrl.call_next(req, depot, res).await;
62                return;
63            }
64        };
65
66        if compressed.is_empty() {
67            ctrl.call_next(req, depot, res).await;
68            return;
69        }
70
71        // Decompress.
72        let decompressed = match decompress_bytes(&compressed, &encoding, self.max_size).await {
73            Ok(d) => d,
74            Err(e) => {
75                debug!(error = %e, encoding = encoding.as_str(), "request decompression failed");
76                res.status_code(http::StatusCode::BAD_REQUEST);
77                res.body("decompression failed");
78                ctrl.skip_rest();
79                return;
80            }
81        };
82
83        debug!(
84            encoding = encoding.as_str(),
85            compressed = compressed.len(),
86            decompressed = decompressed.len(),
87            "decompressed request body"
88        );
89
90        // Replace the body with the decompressed content.
91        req.headers_mut().remove(CONTENT_ENCODING);
92        req.headers_mut()
93            .insert(CONTENT_LENGTH, decompressed.len().into());
94        *req.body_mut() = ReqBody::Once(decompressed.into());
95
96        ctrl.call_next(req, depot, res).await;
97    }
98}
99
100async fn decompress_bytes(data: &[u8], encoding: &str, max_size: usize) -> Result<Vec<u8>, String> {
101    let cursor = Cursor::new(data);
102    let reader = tokio::io::BufReader::new(cursor);
103    let mut output = Vec::new();
104
105    match encoding {
106        "gzip" => {
107            let mut decoder = GzipDecoder::new(reader);
108            read_limited(&mut decoder, &mut output, max_size).await?;
109        }
110        "br" => {
111            let mut decoder = BrotliDecoder::new(reader);
112            read_limited(&mut decoder, &mut output, max_size).await?;
113        }
114        "zstd" => {
115            let mut decoder = ZstdDecoder::new(reader);
116            read_limited(&mut decoder, &mut output, max_size).await?;
117        }
118        "deflate" => {
119            let mut decoder = DeflateDecoder::new(reader);
120            read_limited(&mut decoder, &mut output, max_size).await?;
121        }
122        _ => return Err(format!("unsupported encoding: {encoding}")),
123    }
124
125    Ok(output)
126}
127
128async fn read_limited<R: tokio::io::AsyncRead + Unpin>(
129    reader: &mut R,
130    output: &mut Vec<u8>,
131    max_size: usize,
132) -> Result<(), String> {
133    let mut buf = [0u8; 8192];
134    loop {
135        let n = reader
136            .read(&mut buf)
137            .await
138            .map_err(|e| format!("decompression error: {e}"))?;
139        if n == 0 {
140            break;
141        }
142        if output.len() + n > max_size {
143            return Err(format!(
144                "decompressed body exceeds limit ({max_size} bytes)"
145            ));
146        }
147        output.extend_from_slice(&buf[..n]);
148    }
149    Ok(())
150}