Skip to main content

phantom_frame/
compression.rs

1use crate::CompressStrategy;
2use anyhow::{anyhow, bail, Result};
3use axum::http::HeaderMap;
4use brotli::{CompressorWriter, Decompressor};
5use flate2::{
6    read::{GzDecoder, ZlibDecoder},
7    write::{GzEncoder, ZlibEncoder},
8    Compression,
9};
10use std::io::{Read, Write};
11
/// HTTP content codings this module can both compress and decompress.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentEncoding {
    /// Brotli, header token `br`.
    Brotli,
    /// gzip, header token `gzip`.
    Gzip,
    /// zlib-wrapped deflate, header token `deflate`.
    Deflate,
}

impl ContentEncoding {
    /// Header tokens accepted by [`ContentEncoding::from_header_value`], compared
    /// ignoring ASCII case. `brotli` and `x-gzip` are tolerated aliases.
    const TOKENS: [(&'static str, ContentEncoding); 5] = [
        ("br", ContentEncoding::Brotli),
        ("brotli", ContentEncoding::Brotli),
        ("gzip", ContentEncoding::Gzip),
        ("x-gzip", ContentEncoding::Gzip),
        ("deflate", ContentEncoding::Deflate),
    ];

    /// Returns the canonical header token for this coding: `br`, `gzip` or `deflate`.
    pub fn as_header_value(self) -> &'static str {
        match self {
            ContentEncoding::Brotli => "br",
            ContentEncoding::Gzip => "gzip",
            ContentEncoding::Deflate => "deflate",
        }
    }

    /// Parses a single content-coding token.
    ///
    /// Surrounding whitespace and ASCII case are ignored. Returns `None` for any
    /// unknown token, including `identity`, empty input and comma-separated lists.
    pub fn from_header_value(value: &str) -> Option<Self> {
        let wanted = value.trim();
        Self::TOKENS
            .iter()
            .find(|(token, _)| token.eq_ignore_ascii_case(wanted))
            .map(|&(_, encoding)| encoding)
    }
}
37
38pub fn configured_encoding(strategy: &CompressStrategy) -> Option<ContentEncoding> {
39    match strategy {
40        CompressStrategy::None => None,
41        CompressStrategy::Brotli => Some(ContentEncoding::Brotli),
42        CompressStrategy::Gzip => Some(ContentEncoding::Gzip),
43        CompressStrategy::Deflate => Some(ContentEncoding::Deflate),
44    }
45}
46
47pub fn compress_body(body: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>> {
48    match encoding {
49        ContentEncoding::Brotli => {
50            let mut output = Vec::new();
51            {
52                let mut writer = CompressorWriter::new(&mut output, 4096, 5, 22);
53                writer.write_all(body)?;
54                writer.flush()?;
55            }
56            Ok(output)
57        }
58        ContentEncoding::Gzip => {
59            let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
60            encoder.write_all(body)?;
61            Ok(encoder.finish()?)
62        }
63        ContentEncoding::Deflate => {
64            let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
65            encoder.write_all(body)?;
66            Ok(encoder.finish()?)
67        }
68    }
69}
70
71pub fn decompress_body(body: &[u8], encoding: ContentEncoding) -> Result<Vec<u8>> {
72    let mut output = Vec::new();
73
74    match encoding {
75        ContentEncoding::Brotli => {
76            let mut decoder = Decompressor::new(body, 4096);
77            decoder.read_to_end(&mut output)?;
78        }
79        ContentEncoding::Gzip => {
80            let mut decoder = GzDecoder::new(body);
81            decoder.read_to_end(&mut output)?;
82        }
83        ContentEncoding::Deflate => {
84            let mut decoder = ZlibDecoder::new(body);
85            decoder.read_to_end(&mut output)?;
86        }
87    }
88
89    Ok(output)
90}
91
92pub fn decode_upstream_body(body: &[u8], content_encoding: Option<&str>) -> Result<Vec<u8>> {
93    let Some(content_encoding) = content_encoding else {
94        return Ok(body.to_vec());
95    };
96
97    let normalized = content_encoding.trim();
98    if normalized.is_empty() || normalized.eq_ignore_ascii_case("identity") {
99        return Ok(body.to_vec());
100    }
101
102    let encodings: Vec<&str> = normalized
103        .split(',')
104        .map(str::trim)
105        .filter(|value| !value.is_empty())
106        .collect();
107
108    if encodings.len() != 1 {
109        bail!("unsupported upstream content-encoding chain: {normalized}");
110    }
111
112    let encoding = ContentEncoding::from_header_value(encodings[0])
113        .ok_or_else(|| anyhow!("unsupported upstream content-encoding: {}", encodings[0]))?;
114
115    decompress_body(body, encoding)
116}
117
118pub fn client_accepts_encoding(headers: &HeaderMap, encoding: ContentEncoding) -> bool {
119    let Some(value) = headers.get(axum::http::header::ACCEPT_ENCODING) else {
120        return false;
121    };
122    let Ok(value) = value.to_str() else {
123        return false;
124    };
125
126    encoding_quality(value, encoding.as_header_value()) > 0.0
127}
128
129pub fn identity_acceptable(headers: &HeaderMap) -> bool {
130    let Some(value) = headers.get(axum::http::header::ACCEPT_ENCODING) else {
131        return true;
132    };
133    let Ok(value) = value.to_str() else {
134        return true;
135    };
136
137    let identity_quality = token_quality(value, "identity");
138    if let Some(quality) = identity_quality {
139        return quality > 0.0;
140    }
141
142    match token_quality(value, "*") {
143        Some(quality) => quality > 0.0,
144        None => true,
145    }
146}
147
/// Returns the quality the `Accept-Encoding` list `value` assigns to `encoding`.
///
/// Lookup order:
/// 1. An explicit entry for `encoding` wins.
/// 2. For `gzip`, the legacy alias `x-gzip` is consulted next; RFC 9110 §8.4.1.3
///    says recipients should treat the two as equivalent.
/// 3. Otherwise the `*` wildcard applies.
///
/// If none of these is listed, the coding is unacceptable (`0.0`).
fn encoding_quality(value: &str, encoding: &str) -> f32 {
    token_quality(value, encoding)
        .or_else(|| {
            if encoding.eq_ignore_ascii_case("gzip") {
                token_quality(value, "x-gzip")
            } else {
                None
            }
        })
        .or_else(|| token_quality(value, "*"))
        .unwrap_or(0.0)
}

/// Looks up `token_name` (ignoring ASCII case) in a comma-separated
/// `Accept-Encoding` list and returns its `q` weight.
///
/// - Returns `None` if the token is not listed.
/// - Only the first entry for the token is considered.
/// - The entry gets the default weight `1.0` when it has no `q` parameter, or
///   when its `q` does not parse as a number (including `NaN`).
/// - Parsed weights are clamped to `0.0..=1.0`.
fn token_quality(value: &str, token_name: &str) -> Option<f32> {
    value.split(',').find_map(|item| {
        let mut segments = item.trim().split(';');
        let token = segments.next()?.trim();
        if !token.eq_ignore_ascii_case(token_name) {
            return None;
        }

        let quality = segments
            .find_map(|segment| {
                let (key, raw_value) = segment.trim().split_once('=')?;
                if !key.trim().eq_ignore_ascii_case("q") {
                    return None;
                }
                raw_value
                    .trim()
                    .parse::<f32>()
                    .ok()
                    // `NaN` parses successfully but survives `clamp` and makes every
                    // later comparison false, so treat it like any other malformed q.
                    .filter(|q| !q.is_nan())
            })
            .unwrap_or(1.0);

        Some(quality.clamp(0.0, 1.0))
    })
}
179
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{header::ACCEPT_ENCODING, HeaderMap, HeaderValue};

    const ALL_ENCODINGS: [ContentEncoding; 3] = [
        ContentEncoding::Brotli,
        ContentEncoding::Gzip,
        ContentEncoding::Deflate,
    ];

    /// Builds a header map carrying a single `Accept-Encoding` line.
    fn accept_encoding(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT_ENCODING, HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn test_configured_encoding() {
        let cases = [
            (CompressStrategy::None, None),
            (CompressStrategy::Brotli, Some(ContentEncoding::Brotli)),
            (CompressStrategy::Gzip, Some(ContentEncoding::Gzip)),
            (CompressStrategy::Deflate, Some(ContentEncoding::Deflate)),
        ];

        for (strategy, expected) in &cases {
            assert_eq!(configured_encoding(strategy), *expected);
        }
    }

    #[test]
    fn test_round_trip_compression() {
        let payload = b"phantom-frame compression test body";

        for encoding in ALL_ENCODINGS {
            let packed = compress_body(payload, encoding).unwrap();
            let unpacked = decompress_body(&packed, encoding).unwrap();
            assert_eq!(unpacked, payload);
        }
    }

    #[test]
    fn test_decode_upstream_identity() {
        let plain = b"plain";

        for header in [None, Some("identity")] {
            assert_eq!(decode_upstream_body(plain, header).unwrap(), plain);
        }
    }

    #[test]
    fn test_client_accepts_encoding_with_q_values() {
        let headers = accept_encoding("gzip;q=0.5, br;q=1.0");

        assert!(client_accepts_encoding(&headers, ContentEncoding::Brotli));
        assert!(client_accepts_encoding(&headers, ContentEncoding::Gzip));
        assert!(!client_accepts_encoding(&headers, ContentEncoding::Deflate));
    }

    #[test]
    fn test_identity_acceptable() {
        assert!(identity_acceptable(&HeaderMap::new()));
        assert!(identity_acceptable(&accept_encoding("gzip, br")));
        assert!(!identity_acceptable(&accept_encoding("gzip, identity;q=0")));
        assert!(!identity_acceptable(&accept_encoding("*;q=0")));
    }
}