Skip to main content

phantom_frame/
compression.rs

1use crate::CompressStrategy;
2use anyhow::{anyhow, bail, Result};
3use axum::http::HeaderMap;
4use brotli::{CompressorWriter, Decompressor};
5use flate2::{
6    read::{GzDecoder, ZlibDecoder},
7    write::{GzEncoder, ZlibEncoder},
8    Compression,
9};
10use std::io::{Read, Write};
11use tokio::task;
12
/// Content codings this module knows how to produce and consume.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ContentEncoding {
    /// Brotli; written to headers as `br`.
    Brotli,
    /// Gzip; written to headers as `gzip`.
    Gzip,
    /// Deflate; written to headers as `deflate`.
    Deflate,
}

impl ContentEncoding {
    /// Returns the token emitted for this coding in HTTP header values.
    pub fn as_header_value(self) -> &'static str {
        match self {
            ContentEncoding::Deflate => "deflate",
            ContentEncoding::Gzip => "gzip",
            ContentEncoding::Brotli => "br",
        }
    }

    /// Parses a single coding token.
    ///
    /// Surrounding whitespace is ignored and the comparison is ASCII
    /// case-insensitive. Besides the canonical tokens, the aliases `brotli`
    /// and `x-gzip` are recognised. Anything else yields `None`.
    pub fn from_header_value(value: &str) -> Option<Self> {
        // Every accepted spelling, paired with the coding it maps to.
        const ALIASES: [(&str, ContentEncoding); 5] = [
            ("br", ContentEncoding::Brotli),
            ("brotli", ContentEncoding::Brotli),
            ("gzip", ContentEncoding::Gzip),
            ("x-gzip", ContentEncoding::Gzip),
            ("deflate", ContentEncoding::Deflate),
        ];

        let token = value.trim();
        ALIASES
            .iter()
            .find(|(alias, _)| token.eq_ignore_ascii_case(alias))
            .map(|&(_, encoding)| encoding)
    }
}
38
39pub fn configured_encoding(strategy: &CompressStrategy) -> Option<ContentEncoding> {
40    match strategy {
41        CompressStrategy::None => None,
42        CompressStrategy::Brotli => Some(ContentEncoding::Brotli),
43        CompressStrategy::Gzip => Some(ContentEncoding::Gzip),
44        CompressStrategy::Deflate => Some(ContentEncoding::Deflate),
45    }
46}
47
48pub fn compress_body(body: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>> {
49    match encoding {
50        ContentEncoding::Brotli => {
51            let mut output = Vec::new();
52            {
53                let mut writer = CompressorWriter::new(&mut output, 4096, 5, 22);
54                writer.write_all(body)?;
55                writer.flush()?;
56            }
57            Ok(output)
58        }
59        ContentEncoding::Gzip => {
60            let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
61            encoder.write_all(body)?;
62            Ok(encoder.finish()?)
63        }
64        ContentEncoding::Deflate => {
65            let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
66            encoder.write_all(body)?;
67            Ok(encoder.finish()?)
68        }
69    }
70}
71
72pub async fn compress_body_async(body: Vec<u8>, encoding: ContentEncoding) -> Result<Vec<u8>> {
73    task::spawn_blocking(move || compress_body(&body, encoding))
74        .await
75        .map_err(|error| anyhow!("compression task failed: {}", error))?
76}
77
78pub fn decompress_body(body: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>> {
79    let mut output = Vec::new();
80
81    match encoding {
82        ContentEncoding::Brotli => {
83            let mut decoder = Decompressor::new(body, 4096);
84            decoder.read_to_end(&mut output)?;
85        }
86        ContentEncoding::Gzip => {
87            let mut decoder = GzDecoder::new(body);
88            decoder.read_to_end(&mut output)?;
89        }
90        ContentEncoding::Deflate => {
91            let mut decoder = ZlibDecoder::new(body);
92            decoder.read_to_end(&mut output)?;
93        }
94    }
95
96    Ok(output)
97}
98
99pub async fn decompress_body_async(body: Vec<u8>, encoding: ContentEncoding) -> Result<Vec<u8>> {
100    task::spawn_blocking(move || decompress_body(&body, encoding))
101        .await
102        .map_err(|error| anyhow!("decompression task failed: {}", error))?
103}
104
105pub fn decode_upstream_body(body: &[u8], content_encoding: Option<&str>) -> Result<Vec<u8>> {
106    let Some(content_encoding) = content_encoding else {
107        return Ok(body.to_vec());
108    };
109
110    let normalized = content_encoding.trim();
111    if normalized.is_empty() || normalized.eq_ignore_ascii_case("identity") {
112        return Ok(body.to_vec());
113    }
114
115    let encodings: Vec<&str> = normalized
116        .split(',')
117        .map(str::trim)
118        .filter(|value| !value.is_empty())
119        .collect();
120
121    if encodings.len() != 1 {
122        bail!("unsupported upstream content-encoding chain: {normalized}");
123    }
124
125    let encoding = ContentEncoding::from_header_value(encodings[0])
126        .ok_or_else(|| anyhow!("unsupported upstream content-encoding: {}", encodings[0]))?;
127
128    decompress_body(body, encoding)
129}
130
131pub async fn decode_upstream_body_async(
132    body: Vec<u8>,
133    content_encoding: Option<String>,
134) -> Result<Vec<u8>> {
135    task::spawn_blocking(move || decode_upstream_body(&body, content_encoding.as_deref()))
136        .await
137        .map_err(|error| anyhow!("upstream decode task failed: {}", error))?
138}
139
140pub fn client_accepts_encoding(headers: &HeaderMap, encoding: ContentEncoding) -> bool {
141    let Some(value) = headers.get(axum::http::header::ACCEPT_ENCODING) else {
142        return false;
143    };
144    let Ok(value) = value.to_str() else {
145        return false;
146    };
147
148    encoding_quality(value, encoding.as_header_value()) > 0.0
149}
150
151pub fn identity_acceptable(headers: &HeaderMap) -> bool {
152    let Some(value) = headers.get(axum::http::header::ACCEPT_ENCODING) else {
153        return true;
154    };
155    let Ok(value) = value.to_str() else {
156        return true;
157    };
158
159    let identity_quality = token_quality(value, "identity");
160    if let Some(quality) = identity_quality {
161        return quality > 0.0;
162    }
163
164    match token_quality(value, "*") {
165        Some(quality) => quality > 0.0,
166        None => true,
167    }
168}
169
170fn encoding_quality(value: &str, encoding: &str) -> f32 {
171    if let Some(quality) = token_quality(value, encoding) {
172        return quality;
173    }
174
175    token_quality(value, "*").unwrap_or(0.0)
176}
177
/// Looks up `token_name` in a comma-separated `Accept-Encoding`-style list
/// and returns its quality value.
///
/// * Returns `None` when no list element names the token (ASCII
///   case-insensitive comparison).
/// * The first matching element wins.
/// * The quality is taken from the first `q=` parameter that parses as a
///   number; a missing or malformed `q` defaults to `1.0`.
/// * The result is always clamped to `0.0..=1.0`.
fn token_quality(value: &str, token_name: &str) -> Option<f32> {
    value.split(',').find_map(|item| {
        let mut segments = item.trim().split(';');
        let token = segments.next()?.trim();
        if !token.eq_ignore_ascii_case(token_name) {
            return None;
        }

        let quality = segments
            .find_map(|segment| {
                let (key, raw_value) = segment.split_once('=')?;
                if !key.trim().eq_ignore_ascii_case("q") {
                    return None;
                }
                // `"NaN"` parses successfully as an f32 and survives `clamp`,
                // so reject it here and treat it like any other malformed
                // q-value instead of leaking NaN to callers.
                raw_value
                    .trim()
                    .parse::<f32>()
                    .ok()
                    .filter(|parsed| !parsed.is_nan())
            })
            .unwrap_or(1.0);

        Some(quality.clamp(0.0, 1.0))
    })
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, HeaderValue};

    /// Builds a header map carrying a single `Accept-Encoding` value.
    fn accept_encoding(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::ACCEPT_ENCODING,
            HeaderValue::from_static(value),
        );
        headers
    }

    /// Every strategy maps to the coding of the same name; `None` maps to `None`.
    #[test]
    fn test_configured_encoding() {
        let cases = [
            (CompressStrategy::None, None),
            (CompressStrategy::Brotli, Some(ContentEncoding::Brotli)),
            (CompressStrategy::Gzip, Some(ContentEncoding::Gzip)),
            (CompressStrategy::Deflate, Some(ContentEncoding::Deflate)),
        ];

        for (strategy, expected) in cases {
            assert_eq!(configured_encoding(&strategy), expected);
        }
    }

    /// Compressing then decompressing returns the original bytes for every coding.
    #[test]
    fn test_round_trip_compression() {
        let body = b"phantom-frame compression test body";
        let codings = [
            ContentEncoding::Brotli,
            ContentEncoding::Gzip,
            ContentEncoding::Deflate,
        ];

        for coding in codings {
            let encoded = compress_body(body, coding).unwrap();
            let decoded = decompress_body(&encoded, coding).unwrap();
            assert_eq!(decoded, body);
        }
    }

    /// A missing header and an explicit `identity` both pass the body through.
    #[test]
    fn test_decode_upstream_identity() {
        let body = b"plain";

        for header in [None, Some("identity")] {
            assert_eq!(decode_upstream_body(body, header).unwrap(), body);
        }
    }

    /// Listed codings with positive quality are accepted; unlisted ones are not.
    #[test]
    fn test_client_accepts_encoding_with_q_values() {
        let headers = accept_encoding("gzip;q=0.5, br;q=1.0");

        assert!(client_accepts_encoding(&headers, ContentEncoding::Brotli));
        assert!(client_accepts_encoding(&headers, ContentEncoding::Gzip));
        assert!(!client_accepts_encoding(&headers, ContentEncoding::Deflate));
    }

    /// Identity is acceptable unless excluded via `identity;q=0` or `*;q=0`.
    #[test]
    fn test_identity_acceptable() {
        assert!(identity_acceptable(&HeaderMap::new()));
        assert!(identity_acceptable(&accept_encoding("gzip, br")));
        assert!(!identity_acceptable(&accept_encoding("gzip, identity;q=0")));
        assert!(!identity_acceptable(&accept_encoding("*;q=0")));
    }
}