1use crate::CompressStrategy;
2use anyhow::{anyhow, bail, Result};
3use axum::http::HeaderMap;
4use brotli::{CompressorWriter, Decompressor};
5use flate2::{
6 read::{GzDecoder, ZlibDecoder},
7 write::{GzEncoder, ZlibEncoder},
8 Compression,
9};
10use std::io::{Read, Write};
11use tokio::task;
12
/// Content codings this module can produce and consume.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ContentEncoding {
    /// Brotli, sent on the wire as `br`.
    Brotli,
    /// Gzip, sent on the wire as `gzip`.
    Gzip,
    /// Deflate, sent on the wire as `deflate`.
    Deflate,
}

impl ContentEncoding {
    /// Returns the token written into `Content-Encoding` / matched against
    /// `Accept-Encoding` for this coding.
    pub fn as_header_value(self) -> &'static str {
        match self {
            Self::Brotli => "br",
            Self::Gzip => "gzip",
            Self::Deflate => "deflate",
        }
    }

    /// Parses a single content-coding token.
    ///
    /// Surrounding whitespace is ignored and the comparison is ASCII
    /// case-insensitive. Besides the canonical tokens, the aliases `brotli`
    /// and `x-gzip` are accepted. Returns `None` for anything else,
    /// including the empty string.
    pub fn from_header_value(value: &str) -> Option<Self> {
        // Every spelling we recognise, paired with the coding it selects.
        const ALIASES: [(&str, ContentEncoding); 5] = [
            ("br", ContentEncoding::Brotli),
            ("brotli", ContentEncoding::Brotli),
            ("gzip", ContentEncoding::Gzip),
            ("x-gzip", ContentEncoding::Gzip),
            ("deflate", ContentEncoding::Deflate),
        ];

        let token = value.trim();
        ALIASES
            .iter()
            .find(|(alias, _)| token.eq_ignore_ascii_case(alias))
            .map(|&(_, encoding)| encoding)
    }
}
38
39pub fn configured_encoding(strategy: &CompressStrategy) -> Option<ContentEncoding> {
40 match strategy {
41 CompressStrategy::None => None,
42 CompressStrategy::Brotli => Some(ContentEncoding::Brotli),
43 CompressStrategy::Gzip => Some(ContentEncoding::Gzip),
44 CompressStrategy::Deflate => Some(ContentEncoding::Deflate),
45 }
46}
47
48pub fn compress_body(body: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>> {
49 match encoding {
50 ContentEncoding::Brotli => {
51 let mut output = Vec::new();
52 {
53 let mut writer = CompressorWriter::new(&mut output, 4096, 5, 22);
54 writer.write_all(body)?;
55 writer.flush()?;
56 }
57 Ok(output)
58 }
59 ContentEncoding::Gzip => {
60 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
61 encoder.write_all(body)?;
62 Ok(encoder.finish()?)
63 }
64 ContentEncoding::Deflate => {
65 let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
66 encoder.write_all(body)?;
67 Ok(encoder.finish()?)
68 }
69 }
70}
71
72pub async fn compress_body_async(body: Vec<u8>, encoding: ContentEncoding) -> Result<Vec<u8>> {
73 task::spawn_blocking(move || compress_body(&body, encoding))
74 .await
75 .map_err(|error| anyhow!("compression task failed: {}", error))?
76}
77
78pub fn decompress_body(body: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>> {
79 let mut output = Vec::new();
80
81 match encoding {
82 ContentEncoding::Brotli => {
83 let mut decoder = Decompressor::new(body, 4096);
84 decoder.read_to_end(&mut output)?;
85 }
86 ContentEncoding::Gzip => {
87 let mut decoder = GzDecoder::new(body);
88 decoder.read_to_end(&mut output)?;
89 }
90 ContentEncoding::Deflate => {
91 let mut decoder = ZlibDecoder::new(body);
92 decoder.read_to_end(&mut output)?;
93 }
94 }
95
96 Ok(output)
97}
98
99pub async fn decompress_body_async(body: Vec<u8>, encoding: ContentEncoding) -> Result<Vec<u8>> {
100 task::spawn_blocking(move || decompress_body(&body, encoding))
101 .await
102 .map_err(|error| anyhow!("decompression task failed: {}", error))?
103}
104
105pub fn decode_upstream_body(body: &[u8], content_encoding: Option<&str>) -> Result<Vec<u8>> {
106 let Some(content_encoding) = content_encoding else {
107 return Ok(body.to_vec());
108 };
109
110 let normalized = content_encoding.trim();
111 if normalized.is_empty() || normalized.eq_ignore_ascii_case("identity") {
112 return Ok(body.to_vec());
113 }
114
115 let encodings: Vec<&str> = normalized
116 .split(',')
117 .map(str::trim)
118 .filter(|value| !value.is_empty())
119 .collect();
120
121 if encodings.len() != 1 {
122 bail!("unsupported upstream content-encoding chain: {normalized}");
123 }
124
125 let encoding = ContentEncoding::from_header_value(encodings[0])
126 .ok_or_else(|| anyhow!("unsupported upstream content-encoding: {}", encodings[0]))?;
127
128 decompress_body(body, encoding)
129}
130
131pub async fn decode_upstream_body_async(
132 body: Vec<u8>,
133 content_encoding: Option<String>,
134) -> Result<Vec<u8>> {
135 task::spawn_blocking(move || decode_upstream_body(&body, content_encoding.as_deref()))
136 .await
137 .map_err(|error| anyhow!("upstream decode task failed: {}", error))?
138}
139
140pub fn client_accepts_encoding(headers: &HeaderMap, encoding: ContentEncoding) -> bool {
141 let Some(value) = headers.get(axum::http::header::ACCEPT_ENCODING) else {
142 return false;
143 };
144 let Ok(value) = value.to_str() else {
145 return false;
146 };
147
148 encoding_quality(value, encoding.as_header_value()) > 0.0
149}
150
151pub fn identity_acceptable(headers: &HeaderMap) -> bool {
152 let Some(value) = headers.get(axum::http::header::ACCEPT_ENCODING) else {
153 return true;
154 };
155 let Ok(value) = value.to_str() else {
156 return true;
157 };
158
159 let identity_quality = token_quality(value, "identity");
160 if let Some(quality) = identity_quality {
161 return quality > 0.0;
162 }
163
164 match token_quality(value, "*") {
165 Some(quality) => quality > 0.0,
166 None => true,
167 }
168}
169
170fn encoding_quality(value: &str, encoding: &str) -> f32 {
171 if let Some(quality) = token_quality(value, encoding) {
172 return quality;
173 }
174
175 token_quality(value, "*").unwrap_or(0.0)
176}
177
/// Looks up `token_name` in a comma-separated `Accept-Encoding`-style list
/// and returns its quality weight.
///
/// Returns `None` when no entry names the token (ASCII case-insensitive).
/// For the first matching entry, the first `q=<number>` parameter that
/// parses is used, clamped to `0.0..=1.0`. A missing or malformed weight
/// counts as `1.0`.
///
/// `f32` parsing accepts the literal `NaN`, and `f32::clamp` passes NaN
/// through unchanged. A NaN weight is therefore treated as malformed, like
/// any other unparsable value, so the result is never NaN and callers'
/// `> 0.0` comparisons stay meaningful.
fn token_quality(value: &str, token_name: &str) -> Option<f32> {
    value.split(',').find_map(|item| {
        let mut segments = item.split(';').map(str::trim);
        let token = segments.next()?;
        if !token.eq_ignore_ascii_case(token_name) {
            return None;
        }

        let quality = segments
            .find_map(|segment| {
                let (key, raw_value) = segment.split_once('=')?;
                if !key.trim().eq_ignore_ascii_case("q") {
                    return None;
                }
                raw_value
                    .trim()
                    .parse::<f32>()
                    .ok()
                    // NaN cannot be clamped into range; treat it as malformed.
                    .filter(|quality| !quality.is_nan())
            })
            .unwrap_or(1.0);

        Some(quality.clamp(0.0, 1.0))
    })
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, HeaderValue};

    /// Builds a header map carrying a single `Accept-Encoding` value.
    fn accept_encoding_headers(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static(value),
        );
        headers
    }

    // Every strategy maps to its matching coding; `None` maps to no coding.
    #[test]
    fn test_configured_encoding() {
        let cases = [
            (CompressStrategy::None, None),
            (CompressStrategy::Brotli, Some(ContentEncoding::Brotli)),
            (CompressStrategy::Gzip, Some(ContentEncoding::Gzip)),
            (CompressStrategy::Deflate, Some(ContentEncoding::Deflate)),
        ];

        for (strategy, expected) in cases {
            assert_eq!(configured_encoding(&strategy), expected);
        }
    }

    // Compressing then decompressing yields the original bytes for each coding.
    #[test]
    fn test_round_trip_compression() {
        let body = b"phantom-frame compression test body";
        let encodings = [
            ContentEncoding::Brotli,
            ContentEncoding::Gzip,
            ContentEncoding::Deflate,
        ];

        for encoding in encodings {
            let compressed = compress_body(body, encoding).unwrap();
            let restored = decompress_body(&compressed, encoding).unwrap();
            assert_eq!(restored, body);
        }
    }

    // A missing header and an explicit `identity` both pass the body through.
    #[test]
    fn test_decode_upstream_identity() {
        let body = b"plain";

        for content_encoding in [None, Some("identity")] {
            assert_eq!(decode_upstream_body(body, content_encoding).unwrap(), body);
        }
    }

    // Listed codings with non-zero weights are accepted; unlisted ones are not.
    #[test]
    fn test_client_accepts_encoding_with_q_values() {
        let headers = accept_encoding_headers("gzip;q=0.5, br;q=1.0");

        assert!(client_accepts_encoding(&headers, ContentEncoding::Brotli));
        assert!(client_accepts_encoding(&headers, ContentEncoding::Gzip));
        assert!(!client_accepts_encoding(&headers, ContentEncoding::Deflate));
    }

    // Identity stays acceptable unless `identity` or `*` is given a zero weight.
    #[test]
    fn test_identity_acceptable() {
        assert!(identity_acceptable(&HeaderMap::new()));
        assert!(identity_acceptable(&accept_encoding_headers("gzip, br")));
        assert!(!identity_acceptable(&accept_encoding_headers(
            "gzip, identity;q=0"
        )));
        assert!(!identity_acceptable(&accept_encoding_headers("*;q=0")));
    }
}