Skip to main content

http_encoding/
coding.rs

1use http::{
2    HeaderValue, Response, StatusCode,
3    header::{self, ACCEPT_ENCODING, HeaderMap, HeaderName},
4};
5use http_body_alt::Body;
6
7use super::{
8    coder::{Coder, FeaturedCode},
9    error::FeatureError,
10};
11
/// A content coding that can be negotiated and, when the matching crate feature is
/// enabled, applied to a body.
///
/// The default value is [`ContentEncoding::Identity`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ContentEncoding {
    /// Brotli compression (header token `br`).
    Br,
    /// Deflate-compressed data in the zlib container format (header token `deflate`).
    Deflate,
    /// Gzip compression (header token `gzip`).
    Gzip,
    /// Zstandard compression (header token `zstd`).
    Zstd,
    /// No content coding is applied (header token `identity`).
    #[default]
    Identity,
}
27
28impl ContentEncoding {
29    pub const fn as_header_value(&self) -> HeaderValue {
30        match self {
31            Self::Br => HeaderValue::from_static("br"),
32            Self::Deflate => HeaderValue::from_static("deflate"),
33            Self::Gzip => HeaderValue::from_static("gzip"),
34            Self::Zstd => HeaderValue::from_static("zstd"),
35            Self::Identity => HeaderValue::from_static("identity"),
36        }
37    }
38
39    /// Negotiate content encoding from the `accept-encoding` header.
40    ///
41    /// Returns the highest q-value encoding that is supported by enabled crate features.
42    /// Returns [`ContentEncoding::NoOp`] if no supported encoding is found.
43    pub fn from_headers(headers: &HeaderMap) -> Self {
44        let mut prefer = ContentEncodingWithQValue::default();
45
46        for encoding in Self::_from_headers(headers, &ACCEPT_ENCODING) {
47            prefer.try_update(encoding);
48        }
49
50        prefer.enc
51    }
52
53    /// Negotiate content encoding from a caller-specified header name.
54    ///
55    /// Same as [`from_headers`](Self::from_headers) but reads from `needle` instead of
56    /// `accept-encoding`. Useful for protocols that use a different header for encoding
57    /// negotiation (e.g. `grpc-accept-encoding` for gRPC).
58    pub fn from_headers_with(headers: &HeaderMap, needle: &HeaderName) -> Self {
59        let mut prefer = ContentEncodingWithQValue::default();
60
61        for encoding in Self::_from_headers(headers, needle) {
62            prefer.try_update(encoding);
63        }
64
65        prefer.enc
66    }
67
68    /// Encode a response body, updating response headers accordingly.
69    ///
70    /// Skips encoding if the response already has a `Content-Encoding` header, or if the
71    /// status is `101 Switching Protocols` or `204 No Content`.
72    ///
73    /// # Headers
74    /// Sets `Content-Encoding` and removes `Content-Length` when compression is applied.
75    /// For HTTP/1.1 responses, `Transfer-Encoding: chunked` is appended. Callers should avoid
76    /// modifying `Content-Encoding` or `Transfer-Encoding` headers after calling this method,
77    /// as inconsistent values may cause incorrect behavior in downstream clients.
78    pub fn try_encode<S>(mut self, response: Response<S>) -> Response<Coder<S, FeaturedCode>>
79    where
80        S: Body,
81        S::Data: AsRef<[u8]> + 'static,
82    {
83        #[allow(unused_mut)]
84        let (mut parts, body) = response.into_parts();
85
86        if parts.headers.contains_key(&header::CONTENT_ENCODING)
87            || parts.status == StatusCode::SWITCHING_PROTOCOLS
88            || parts.status == StatusCode::NO_CONTENT
89        {
90            self = ContentEncoding::Identity;
91        }
92
93        self.update_header(&mut parts.headers, parts.version);
94
95        let body = self.encode_body(body);
96        Response::from_parts(parts, body)
97    }
98
99    /// Encode a [`Body`] with featured encoder
100    pub fn encode_body<S>(self, body: S) -> Coder<S, FeaturedCode>
101    where
102        S: Body,
103        S::Data: AsRef<[u8]> + 'static,
104    {
105        let encoder = match self {
106            #[cfg(feature = "de")]
107            ContentEncoding::Deflate => FeaturedCode::EncodeDe(super::deflate::Encoder::new(
108                super::writer::BytesMutWriter::new(),
109                flate2::Compression::fast(),
110            )),
111            #[cfg(feature = "gz")]
112            ContentEncoding::Gzip => FeaturedCode::EncodeGz(super::gzip::Encoder::new(
113                super::writer::BytesMutWriter::new(),
114                flate2::Compression::fast(),
115            )),
116            #[cfg(feature = "br")]
117            ContentEncoding::Br => FeaturedCode::EncodeBr(super::brotli::Encoder::new(3)),
118            #[cfg(feature = "zs")]
119            ContentEncoding::Zstd => FeaturedCode::EncodeZs(super::zstandard::Encoder::new(3)),
120            _ => FeaturedCode::default(),
121        };
122
123        Coder::new(body, encoder)
124    }
125
126    /// Decode a [`Body`] with featured decoder.
127    ///
128    /// Symmetric to [`encode_body`](Self::encode_body). Use this when decoding outside of
129    /// an HTTP response context (e.g. gRPC request decompression) where the encoding is
130    /// determined by a protocol-specific header rather than `Content-Encoding`.
131    pub fn decode_body<S>(self, body: S) -> Coder<S, FeaturedCode>
132    where
133        S: Body,
134        S::Data: AsRef<[u8]> + 'static,
135    {
136        let decoder = match self {
137            #[cfg(feature = "de")]
138            ContentEncoding::Deflate => {
139                FeaturedCode::DecodeDe(super::deflate::Decoder::new(super::writer::BytesMutWriter::new()))
140            }
141            #[cfg(feature = "gz")]
142            ContentEncoding::Gzip => {
143                FeaturedCode::DecodeGz(super::gzip::Decoder::new(super::writer::BytesMutWriter::new()))
144            }
145            #[cfg(feature = "br")]
146            ContentEncoding::Br => FeaturedCode::DecodeBr(super::brotli::Decoder::new()),
147            #[cfg(feature = "zs")]
148            ContentEncoding::Zstd => FeaturedCode::DecodeZs(super::zstandard::Decoder::new()),
149            _ => FeaturedCode::default(),
150        };
151
152        Coder::new(body, decoder)
153    }
154
155    fn update_header(self, headers: &mut HeaderMap, version: http::Version) {
156        if matches!(self, ContentEncoding::Identity) {
157            return;
158        }
159
160        headers.insert(header::CONTENT_ENCODING, self.as_header_value());
161        headers.remove(header::CONTENT_LENGTH);
162
163        // Connection specific headers are not allowed in HTTP/2 and later versions.
164        // see https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2
165        if version == http::Version::HTTP_11 {
166            headers.append(header::TRANSFER_ENCODING, header::HeaderValue::from_static("chunked"));
167        }
168    }
169
170    fn _from_headers(headers: &HeaderMap, needle: &HeaderName) -> impl Iterator<Item = ContentEncodingWithQValue> {
171        headers
172            .get_all(needle)
173            .iter()
174            .filter_map(|hval| hval.to_str().ok())
175            .flat_map(|s| s.split(','))
176            .filter_map(|v| {
177                let mut v = v.splitn(2, ';');
178                Self::try_parse(v.next().unwrap().trim()).ok().map(|enc| {
179                    let val = v
180                        .next()
181                        .and_then(|v| QValue::parse(v.trim()))
182                        .unwrap_or_else(QValue::one);
183                    ContentEncodingWithQValue { enc, val }
184                })
185            })
186    }
187
188    pub(super) fn try_parse(s: &str) -> Result<Self, FeatureError> {
189        if s.eq_ignore_ascii_case("gzip") {
190            Ok(Self::Gzip)
191        } else if s.eq_ignore_ascii_case("deflate") {
192            Ok(Self::Deflate)
193        } else if s.eq_ignore_ascii_case("br") {
194            Ok(Self::Br)
195        } else if s.eq_ignore_ascii_case("zstd") {
196            Ok(Self::Zstd)
197        } else if s.eq_ignore_ascii_case("identity") {
198            Ok(Self::Identity)
199        } else {
200            Err(FeatureError::Unknown(s.to_string().into_boxed_str()))
201        }
202    }
203}
204
205struct ContentEncodingWithQValue {
206    enc: ContentEncoding,
207    val: QValue,
208}
209
210impl Default for ContentEncodingWithQValue {
211    fn default() -> Self {
212        Self {
213            enc: ContentEncoding::Identity,
214            val: QValue::zero(),
215        }
216    }
217}
218
219impl ContentEncodingWithQValue {
220    fn try_update(&mut self, other: Self) {
221        if other.val > self.val {
222            match other.enc {
223                #[cfg(not(feature = "br"))]
224                ContentEncoding::Br => return,
225                #[cfg(not(feature = "de"))]
226                ContentEncoding::Deflate => return,
227                #[cfg(not(feature = "gz"))]
228                ContentEncoding::Gzip => return,
229                #[cfg(not(feature = "zs"))]
230                ContentEncoding::Zstd => return,
231                _ => {}
232            };
233            *self = other;
234        }
235    }
236}
237
/// A request weight ("q-value") stored as an integer count of thousandths. For example,
/// `q=0.5` is `500` and `q=1` is `1000`; ordering follows the numeric weight.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct QValue(u16);

impl QValue {
    /// The lowest weight, `q=0`.
    const fn zero() -> Self {
        Self(0)
    }

    /// The highest weight, `q=1`.
    const fn one() -> Self {
        Self(1000)
    }

    /// Parse a `q=<qvalue>` weight parameter as specified in RFC 7231 section 5.3.1.
    ///
    /// Rules:
    /// - the leading `q` may be either case;
    /// - the integer part must be `0` or `1`, optionally followed by `.` and at most three
    ///   fractional digits;
    /// - results above `1.000` are rejected.
    ///
    /// Any other deviation yields `None`, including surrounding whitespace.
    fn parse(s: &str) -> Option<Self> {
        let rest = s.strip_prefix('q').or_else(|| s.strip_prefix('Q'))?;
        let rest = rest.strip_prefix('=')?;

        let mut bytes = rest.bytes();

        // Valid weights lie within 0.000..=1.000, so the integer part is a single 0 or 1.
        let whole: u16 = match bytes.next()? {
            b'0' => 0,
            b'1' => 1000,
            _ => return None,
        };

        // With no decimal point, the integer part is the entire value.
        match bytes.next() {
            None => return Some(Self(whole)),
            Some(b'.') => {}
            Some(_) => return None,
        }

        // Up to three fractional digits, worth 100, 10 and 1 thousandths respectively.
        const PLACES: [u16; 3] = [100, 10, 1];
        let mut total = whole;
        for (position, byte) in bytes.enumerate() {
            if !byte.is_ascii_digit() || position >= PLACES.len() {
                return None;
            }
            total += PLACES[position] * u16::from(byte - b'0');
        }

        if total <= 1000 { Some(Self(total)) } else { None }
    }
}