// gatel_core/hoops/decompress.rs
use std::io::Cursor;

use async_compression::tokio::bufread::{
    BrotliDecoder, DeflateDecoder, GzipDecoder, ZlibDecoder, ZstdDecoder,
};
use http::header::{CONTENT_ENCODING, CONTENT_LENGTH};
use salvo::http::ReqBody;
use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
use tokio::io::AsyncReadExt;
use tracing::debug;

/// Request hoop that decompresses request bodies sent with a supported
/// `Content-Encoding` before they reach later handlers.
#[derive(Debug, Clone)]
pub struct DecompressHoop {
    /// Upper bound, in bytes, on the size of a decompressed request body.
    max_size: usize,
}

21impl DecompressHoop {
22 pub fn new(max_size: Option<usize>) -> Self {
27 Self {
28 max_size: max_size.unwrap_or(64 * 1024 * 1024),
29 }
30 }
31}
32
33#[async_trait]
34impl salvo::Handler for DecompressHoop {
35 async fn handle(
36 &self,
37 req: &mut Request,
38 depot: &mut Depot,
39 res: &mut Response,
40 ctrl: &mut FlowCtrl,
41 ) {
42 let encoding = req
43 .headers()
44 .get(CONTENT_ENCODING)
45 .and_then(|v| v.to_str().ok())
46 .map(|s| s.to_ascii_lowercase());
47
48 let encoding = match encoding {
49 Some(e) if e == "gzip" || e == "br" || e == "zstd" || e == "deflate" => e,
50 _ => {
51 ctrl.call_next(req, depot, res).await;
53 return;
54 }
55 };
56
57 let compressed = match req.payload().await {
59 Ok(bytes) => bytes.to_vec(),
60 Err(_) => {
61 ctrl.call_next(req, depot, res).await;
62 return;
63 }
64 };
65
66 if compressed.is_empty() {
67 ctrl.call_next(req, depot, res).await;
68 return;
69 }
70
71 let decompressed = match decompress_bytes(&compressed, &encoding, self.max_size).await {
73 Ok(d) => d,
74 Err(e) => {
75 debug!(error = %e, encoding = encoding.as_str(), "request decompression failed");
76 res.status_code(http::StatusCode::BAD_REQUEST);
77 res.body("decompression failed");
78 ctrl.skip_rest();
79 return;
80 }
81 };
82
83 debug!(
84 encoding = encoding.as_str(),
85 compressed = compressed.len(),
86 decompressed = decompressed.len(),
87 "decompressed request body"
88 );
89
90 req.headers_mut().remove(CONTENT_ENCODING);
92 req.headers_mut()
93 .insert(CONTENT_LENGTH, decompressed.len().into());
94 *req.body_mut() = ReqBody::Once(decompressed.into());
95
96 ctrl.call_next(req, depot, res).await;
97 }
98}
99
100async fn decompress_bytes(data: &[u8], encoding: &str, max_size: usize) -> Result<Vec<u8>, String> {
101 let cursor = Cursor::new(data);
102 let reader = tokio::io::BufReader::new(cursor);
103 let mut output = Vec::new();
104
105 match encoding {
106 "gzip" => {
107 let mut decoder = GzipDecoder::new(reader);
108 read_limited(&mut decoder, &mut output, max_size).await?;
109 }
110 "br" => {
111 let mut decoder = BrotliDecoder::new(reader);
112 read_limited(&mut decoder, &mut output, max_size).await?;
113 }
114 "zstd" => {
115 let mut decoder = ZstdDecoder::new(reader);
116 read_limited(&mut decoder, &mut output, max_size).await?;
117 }
118 "deflate" => {
119 let mut decoder = DeflateDecoder::new(reader);
120 read_limited(&mut decoder, &mut output, max_size).await?;
121 }
122 _ => return Err(format!("unsupported encoding: {encoding}")),
123 }
124
125 Ok(output)
126}
127
128async fn read_limited<R: tokio::io::AsyncRead + Unpin>(
129 reader: &mut R,
130 output: &mut Vec<u8>,
131 max_size: usize,
132) -> Result<(), String> {
133 let mut buf = [0u8; 8192];
134 loop {
135 let n = reader
136 .read(&mut buf)
137 .await
138 .map_err(|e| format!("decompression error: {e}"))?;
139 if n == 0 {
140 break;
141 }
142 if output.len() + n > max_size {
143 return Err(format!(
144 "decompressed body exceeds limit ({max_size} bytes)"
145 ));
146 }
147 output.extend_from_slice(&buf[..n]);
148 }
149 Ok(())
150}