1use std::io::Cursor;
2
3use async_compression::Level;
4use async_compression::tokio::bufread::{BrotliEncoder, DeflateEncoder, GzipEncoder, ZstdEncoder};
5use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE};
6use http_body_util::BodyExt;
7use salvo::http::ResBody;
8use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
9use tokio::io::AsyncReadExt;
10use tracing::debug;
11
/// Smallest response body, in bytes, the hoop will try to compress; shorter
/// bodies are sent unchanged.
const MIN_COMPRESS_SIZE: usize = 1 << 8;
14
/// Content codings this hoop is able to produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Encoding {
    /// Zstandard, advertised as `zstd`.
    Zstd,
    /// Brotli, advertised as `br`.
    Brotli,
    /// gzip, advertised as `gzip`.
    Gzip,
    /// DEFLATE, advertised as `deflate`.
    Deflate,
}

impl Encoding {
    /// Header token for this coding, as used in `Accept-Encoding` matching and
    /// written into `Content-Encoding`.
    fn as_str(self) -> &'static str {
        match self {
            Self::Zstd => "zstd",
            Self::Brotli => "br",
            Self::Gzip => "gzip",
            Self::Deflate => "deflate",
        }
    }
}
34
35pub struct CompressHoop {
41 enabled: Vec<Encoding>,
43 level: Option<u32>,
45}
46
47impl CompressHoop {
48 pub fn new(encodings: &[String], level: Option<u32>) -> Self {
50 let mut enabled = Vec::new();
51 for name in encodings {
52 match name.as_str() {
53 "gzip" => enabled.push(Encoding::Gzip),
54 "zstd" => enabled.push(Encoding::Zstd),
55 "br" | "brotli" => enabled.push(Encoding::Brotli),
56 "deflate" => enabled.push(Encoding::Deflate),
57 other => {
58 debug!(encoding = other, "unknown encoding requested, skipping");
59 }
60 }
61 }
62 if enabled.is_empty() {
64 enabled = vec![
65 Encoding::Zstd,
66 Encoding::Brotli,
67 Encoding::Gzip,
68 Encoding::Deflate,
69 ];
70 }
71 Self { enabled, level }
72 }
73}
74
75#[async_trait]
76impl salvo::Handler for CompressHoop {
77 async fn handle(
78 &self,
79 req: &mut Request,
80 depot: &mut Depot,
81 res: &mut Response,
82 ctrl: &mut FlowCtrl,
83 ) {
84 let chosen = choose_encoding(req.headers(), &self.enabled);
86
87 ctrl.call_next(req, depot, res).await;
88
89 let encoding = match chosen {
91 Some(e) => e,
92 None => return,
93 };
94
95 if res.headers().contains_key(CONTENT_ENCODING) {
97 return;
98 }
99
100 if !is_compressible_content_type(res.headers()) {
102 return;
103 }
104
105 let body = res.take_body();
107 let body_bytes = match collect_res_body(body).await {
108 Ok(bytes) => bytes,
109 Err(_) => return,
110 };
111
112 if body_bytes.len() < MIN_COMPRESS_SIZE {
114 res.body(body_bytes);
115 return;
116 }
117
118 let compressed = match compress_bytes(&body_bytes, encoding, self.level).await {
120 Ok(c) => c,
121 Err(_) => {
122 res.body(body_bytes);
123 return;
124 }
125 };
126
127 debug!(
128 encoding = encoding.as_str(),
129 original = body_bytes.len(),
130 compressed = compressed.len(),
131 "compressed response"
132 );
133
134 res.headers_mut()
136 .insert(CONTENT_ENCODING, encoding.as_str().parse().unwrap());
137 res.headers_mut().remove(CONTENT_LENGTH);
138 res.headers_mut()
139 .insert(CONTENT_LENGTH, compressed.len().into());
140
141 res.body(compressed);
142 }
143}
144
145pub async fn collect_res_body_bytes(body: ResBody) -> Result<Vec<u8>, ()> {
147 collect_res_body(body).await
148}
149
150async fn collect_res_body(body: ResBody) -> Result<Vec<u8>, ()> {
152 match body {
153 ResBody::None => Ok(Vec::new()),
154 ResBody::Once(bytes) => Ok(bytes.to_vec()),
155 ResBody::Boxed(boxed) => {
156 let collected = boxed.collect().await.map_err(|_| ())?;
157 Ok(collected.to_bytes().to_vec())
158 }
159 other => {
160 use http_body::Body;
162 let mut buf = Vec::new();
163 let mut pinned = Box::pin(other);
164 loop {
165 match std::future::poll_fn(|cx| pinned.as_mut().poll_frame(cx)).await {
166 Some(Ok(frame)) => {
167 if let Ok(data) = frame.into_data() {
168 buf.extend_from_slice(&data);
169 }
170 }
171 Some(Err(_)) => return Err(()),
172 None => break,
173 }
174 }
175 Ok(buf)
176 }
177 }
178}
179
180fn choose_encoding(headers: &http::HeaderMap, enabled: &[Encoding]) -> Option<Encoding> {
182 let accept = headers.get(ACCEPT_ENCODING)?.to_str().ok()?;
183 for enc in enabled {
186 let token = enc.as_str();
187 if accept.contains(token) || accept.contains("*") {
188 return Some(*enc);
189 }
190 }
191 None
192}
193
194fn is_compressible_content_type(headers: &http::HeaderMap) -> bool {
196 let ct = match headers.get(CONTENT_TYPE) {
197 Some(v) => match v.to_str() {
198 Ok(s) => s.to_ascii_lowercase(),
199 Err(_) => return false,
200 },
201 None => return false,
203 };
204
205 if ct.starts_with("text/") {
207 return true;
208 }
209
210 const COMPRESSIBLE: &[&str] = &[
211 "application/json",
212 "application/javascript",
213 "application/xml",
214 "application/xhtml+xml",
215 "application/rss+xml",
216 "application/atom+xml",
217 "application/wasm",
218 "application/manifest+json",
219 "application/ld+json",
220 "application/graphql+json",
221 "application/geo+json",
222 "application/vnd.api+json",
223 "image/svg+xml",
224 ];
225
226 for mime in COMPRESSIBLE {
227 if ct.starts_with(mime) {
228 return true;
229 }
230 }
231
232 false
233}
234
235async fn compress_bytes(
237 data: &[u8],
238 encoding: Encoding,
239 level: Option<u32>,
240) -> Result<Vec<u8>, crate::ProxyError> {
241 let cursor = Cursor::new(data);
242 let reader = tokio::io::BufReader::new(cursor);
243 let mut output = Vec::new();
244 let compress_level = level.map(|l| Level::Precise(l as i32)).unwrap_or_default();
245
246 match encoding {
247 Encoding::Gzip => {
248 let mut encoder = GzipEncoder::with_quality(reader, compress_level);
249 encoder
250 .read_to_end(&mut output)
251 .await
252 .map_err(|e| crate::ProxyError::Internal(format!("gzip compression error: {e}")))?;
253 }
254 Encoding::Zstd => {
255 let mut encoder = ZstdEncoder::with_quality(reader, compress_level);
256 encoder
257 .read_to_end(&mut output)
258 .await
259 .map_err(|e| crate::ProxyError::Internal(format!("zstd compression error: {e}")))?;
260 }
261 Encoding::Brotli => {
262 let mut encoder = BrotliEncoder::with_quality(reader, compress_level);
263 encoder.read_to_end(&mut output).await.map_err(|e| {
264 crate::ProxyError::Internal(format!("brotli compression error: {e}"))
265 })?;
266 }
267 Encoding::Deflate => {
268 let mut encoder = DeflateEncoder::with_quality(reader, compress_level);
269 encoder.read_to_end(&mut output).await.map_err(|e| {
270 crate::ProxyError::Internal(format!("deflate compression error: {e}"))
271 })?;
272 }
273 }
274
275 Ok(output)
276}