// phantom_frame/compression.rs
use crate::CompressStrategy;
use anyhow::{anyhow, bail, Result};
use axum::http::HeaderMap;
use brotli::{CompressorWriter, Decompressor};
use flate2::{
    read::{DeflateDecoder, GzDecoder, ZlibDecoder},
    write::{GzEncoder, ZlibEncoder},
    Compression,
};
use std::io::{Read, Write};
11
/// HTTP content codings this module can produce (for responses) and consume
/// (from upstream bodies).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentEncoding {
    /// Brotli, advertised as `br`.
    Brotli,
    /// Gzip, advertised as `gzip`.
    Gzip,
    /// zlib-wrapped DEFLATE, advertised as `deflate`.
    Deflate,
}
18
19impl ContentEncoding {
20 pub fn as_header_value(self) -> &'static str {
21 match self {
22 Self::Brotli => "br",
23 Self::Gzip => "gzip",
24 Self::Deflate => "deflate",
25 }
26 }
27
28 pub fn from_header_value(value: &str) -> Option<Self> {
29 match value.trim().to_ascii_lowercase().as_str() {
30 "br" | "brotli" => Some(Self::Brotli),
31 "gzip" | "x-gzip" => Some(Self::Gzip),
32 "deflate" => Some(Self::Deflate),
33 _ => None,
34 }
35 }
36}
37
38pub fn configured_encoding(strategy: &CompressStrategy) -> Option<ContentEncoding> {
39 match strategy {
40 CompressStrategy::None => None,
41 CompressStrategy::Brotli => Some(ContentEncoding::Brotli),
42 CompressStrategy::Gzip => Some(ContentEncoding::Gzip),
43 CompressStrategy::Deflate => Some(ContentEncoding::Deflate),
44 }
45}
46
47pub fn compress_body(body: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>> {
48 match encoding {
49 ContentEncoding::Brotli => {
50 let mut output = Vec::new();
51 {
52 let mut writer = CompressorWriter::new(&mut output, 4096, 5, 22);
53 writer.write_all(body)?;
54 writer.flush()?;
55 }
56 Ok(output)
57 }
58 ContentEncoding::Gzip => {
59 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
60 encoder.write_all(body)?;
61 Ok(encoder.finish()?)
62 }
63 ContentEncoding::Deflate => {
64 let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
65 encoder.write_all(body)?;
66 Ok(encoder.finish()?)
67 }
68 }
69}
70
71pub fn decompress_body(body: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>> {
72 let mut output = Vec::new();
73
74 match encoding {
75 ContentEncoding::Brotli => {
76 let mut decoder = Decompressor::new(body, 4096);
77 decoder.read_to_end(&mut output)?;
78 }
79 ContentEncoding::Gzip => {
80 let mut decoder = GzDecoder::new(body);
81 decoder.read_to_end(&mut output)?;
82 }
83 ContentEncoding::Deflate => {
84 let mut decoder = ZlibDecoder::new(body);
85 decoder.read_to_end(&mut output)?;
86 }
87 }
88
89 Ok(output)
90}
91
92pub fn decode_upstream_body(body: &[u8], content_encoding: Option<&str>) -> Result<Vec<u8>> {
93 let Some(content_encoding) = content_encoding else {
94 return Ok(body.to_vec());
95 };
96
97 let normalized = content_encoding.trim();
98 if normalized.is_empty() || normalized.eq_ignore_ascii_case("identity") {
99 return Ok(body.to_vec());
100 }
101
102 let encodings: Vec<&str> = normalized
103 .split(',')
104 .map(str::trim)
105 .filter(|value| !value.is_empty())
106 .collect();
107
108 if encodings.len() != 1 {
109 bail!("unsupported upstream content-encoding chain: {normalized}");
110 }
111
112 let encoding = ContentEncoding::from_header_value(encodings[0])
113 .ok_or_else(|| anyhow!("unsupported upstream content-encoding: {}", encodings[0]))?;
114
115 decompress_body(body, encoding)
116}
117
118pub fn client_accepts_encoding(headers: &HeaderMap, encoding: ContentEncoding) -> bool {
119 let Some(value) = headers.get(axum::http::header::ACCEPT_ENCODING) else {
120 return false;
121 };
122 let Ok(value) = value.to_str() else {
123 return false;
124 };
125
126 encoding_quality(value, encoding.as_header_value()) > 0.0
127}
128
129pub fn identity_acceptable(headers: &HeaderMap) -> bool {
130 let Some(value) = headers.get(axum::http::header::ACCEPT_ENCODING) else {
131 return true;
132 };
133 let Ok(value) = value.to_str() else {
134 return true;
135 };
136
137 let identity_quality = token_quality(value, "identity");
138 if let Some(quality) = identity_quality {
139 return quality > 0.0;
140 }
141
142 match token_quality(value, "*") {
143 Some(quality) => quality > 0.0,
144 None => true,
145 }
146}
147
148fn encoding_quality(value: &str, encoding: &str) -> f32 {
149 if let Some(quality) = token_quality(value, encoding) {
150 return quality;
151 }
152
153 token_quality(value, "*").unwrap_or(0.0)
154}
155
/// Looks up `token_name` (ASCII case-insensitive) in the comma-separated,
/// `Accept-Encoding`-style list `value` and returns its quality.
///
/// Returns `None` if the token is not listed; the first matching entry wins.
/// A missing or unparseable `q` parameter yields `1.0`, and `NaN` counts as
/// unparseable. Parsed values are clamped to `0.0..=1.0`.
fn token_quality(value: &str, token_name: &str) -> Option<f32> {
    value.split(',').find_map(|item| {
        let mut segments = item.trim().split(';');
        let token = segments.next()?.trim();
        if !token.eq_ignore_ascii_case(token_name) {
            return None;
        }

        let quality = segments
            .find_map(|segment| {
                let (key, raw_value) = segment.trim().split_once('=')?;
                if !key.trim().eq_ignore_ascii_case("q") {
                    return None;
                }
                // `f32::from_str` accepts "NaN", which survives `clamp` and compares
                // false against every threshold. Treat it like any other garbage value.
                raw_value.trim().parse::<f32>().ok().filter(|q| !q.is_nan())
            })
            .unwrap_or(1.0);

        Some(quality.clamp(0.0, 1.0))
    })
}
179
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, HeaderValue};

    /// Builds a request header map whose only `Accept-Encoding` value is `value`.
    fn with_accept_encoding(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static(value),
        );
        headers
    }

    #[test]
    fn test_configured_encoding() {
        let expectations = [
            (CompressStrategy::None, None),
            (CompressStrategy::Brotli, Some(ContentEncoding::Brotli)),
            (CompressStrategy::Gzip, Some(ContentEncoding::Gzip)),
            (CompressStrategy::Deflate, Some(ContentEncoding::Deflate)),
        ];

        for (strategy, expected) in expectations {
            assert_eq!(configured_encoding(&strategy), expected);
        }
    }

    #[test]
    fn test_round_trip_compression() {
        let original: &[u8] = b"phantom-frame compression test body";
        let all_encodings = [
            ContentEncoding::Brotli,
            ContentEncoding::Gzip,
            ContentEncoding::Deflate,
        ];

        for encoding in all_encodings {
            let packed = compress_body(original, encoding).unwrap();
            let unpacked = decompress_body(&packed, encoding).unwrap();
            assert_eq!(unpacked, original, "round trip failed for {encoding:?}");
        }
    }

    #[test]
    fn test_decode_upstream_identity() {
        let plain: &[u8] = b"plain";
        for header in [None, Some("identity")] {
            assert_eq!(decode_upstream_body(plain, header).unwrap(), plain);
        }
    }

    #[test]
    fn test_client_accepts_encoding_with_q_values() {
        let headers = with_accept_encoding("gzip;q=0.5, br;q=1.0");

        assert!(client_accepts_encoding(&headers, ContentEncoding::Brotli));
        assert!(client_accepts_encoding(&headers, ContentEncoding::Gzip));
        assert!(!client_accepts_encoding(&headers, ContentEncoding::Deflate));
    }

    #[test]
    fn test_identity_acceptable() {
        assert!(identity_acceptable(&HeaderMap::new()));
        assert!(identity_acceptable(&with_accept_encoding("gzip, br")));
        assert!(!identity_acceptable(&with_accept_encoding("gzip, identity;q=0")));
        assert!(!identity_acceptable(&with_accept_encoding("*;q=0")));
    }
}