hyper_content_encoding/
lib.rs

1//! Utility for handling Content-Encoding with [`hyper`](https://docs.rs/hyper)
2//!
3//! This crate currently only supports `gzip`, `deflate` and `identity`
4
5// TODO:
6// List of encodings:
7// https://www.iana.org/assignments/http-parameters/http-parameters.xml#http-parameters-1
8
use bytes::Bytes;
use core::fmt;
use flate2::{
    read::{DeflateDecoder, DeflateEncoder, GzDecoder, GzEncoder, ZlibDecoder, ZlibEncoder},
    Compression,
};
use http::{
    header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE},
    HeaderValue, Request, Response, StatusCode,
};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{body::Incoming, service::Service};
use std::{fmt::Debug, future::Future, pin::Pin};
use std::{io::prelude::*, str::FromStr};
23
/// Crate-local result alias whose error type is [`HyperContentEncodingError`].
type Result<T> = std::result::Result<T, HyperContentEncodingError>;

/// The error used in hyper-content-encoding.
///
/// It only carries a human-readable message; foreign errors are turned into
/// this type by formatting them into that message.
#[derive(Debug, Clone)]
pub struct HyperContentEncodingError {
    inner: String,
}

impl HyperContentEncodingError {
    /// Builds an error whose message is `inner`.
    pub fn new(inner: String) -> Self {
        Self { inner }
    }
}

impl std::error::Error for HyperContentEncodingError {}

impl From<String> for HyperContentEncodingError {
    fn from(value: String) -> Self {
        Self::new(value)
    }
}

impl fmt::Display for HyperContentEncodingError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.inner)
    }
}
51
52fn convert_err<E>(e: E) -> HyperContentEncodingError
53where
54    E: Debug,
55{
56    format!("{:?}", e).into()
57}
58
59trait Decoder<T>: Read
60where
61    T: Read,
62{
63    fn new(_: T) -> Self;
64}
65
66impl<R> Decoder<R> for GzDecoder<R>
67where
68    R: Read,
69{
70    fn new(reader: R) -> Self {
71        GzDecoder::new(reader)
72    }
73}
74
75impl<R> Decoder<R> for DeflateDecoder<R>
76where
77    R: Read,
78{
79    fn new(reader: R) -> Self {
80        DeflateDecoder::new(reader)
81    }
82}
83
84impl<R> Decoder<R> for ZlibDecoder<R>
85where
86    R: Read,
87{
88    fn new(reader: R) -> Self {
89        ZlibDecoder::new(reader)
90    }
91}
92
93fn decompress<'a, T>(body: &'a Bytes) -> Result<String>
94where
95    T: Decoder<&'a [u8]>,
96{
97    let reader: &[u8] = body;
98    let mut decoder = T::new(reader);
99    let mut s = String::new();
100    decoder.read_to_string(&mut s).map_err(convert_err)?;
101    Ok(s)
102}
103
104/// Extracts the body of a response as a String.
105/// The response *must* contain a `Content-Type` with the word `text`
106///
107/// Currently the only handled `Content-Encodings` that are supported are
108/// - `identity`
109/// - `x-gzip`
110/// - `gzip`
111/// - `deflate`
112/// TODO: stream
113pub async fn response_to_string(res: Response<Incoming>) -> Result<String> {
114    if let Some(content_type) = res.headers().get(CONTENT_TYPE) {
115        if content_type.to_str().map_err(convert_err)?.contains("text") {
116            if let Some(encoding) = Encoding::get_response_encoding(&res) {
117                let res = res.map(|b| b.boxed());
118                let body: Bytes = res.collect().await.map_err(convert_err)?.to_bytes();
119
120                match encoding {
121                    Encoding::Gzip => decompress::<GzDecoder<_>>(&body),
122                    Encoding::Deflate => decompress::<DeflateDecoder<_>>(&body),
123                    Encoding::Identity => {
124                        let body = String::from_utf8_lossy(&body).to_string();
125                        Ok(body)
126                    }
127                }
128            } else {
129                Err(format!("Unknown Content-Type").into())
130            }
131        } else {
132            Err("Content-Type does not specify text".to_string().into())
133        }
134    } else {
135        Err("No Content-Type specified".to_string().into())
136    }
137}
138
/// The different supported encoding types.
///
/// Each variant corresponds to an HTTP content-coding token. The type is a
/// small fieldless enum, so it is `Copy` and usable as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Encoding {
    /// `gzip` (also parsed from its `x-gzip` alias)
    Gzip,

    /// `deflate`
    Deflate,

    /// `identity`: the body is not transformed
    Identity,
}
151
152impl Encoding {
153    fn as_header_value(&self) -> HeaderValue {
154        match &self {
155            Encoding::Gzip => HeaderValue::from_static("gzip"),
156
157            Encoding::Deflate => HeaderValue::from_static("deflate"),
158
159            Encoding::Identity => HeaderValue::from_static("identity"),
160        }
161    }
162
163    /// Retrieves the content encoding of a response.
164    /// If the encoding is not supported returns None
165    pub fn get_response_encoding(res: &Response<Incoming>) -> Option<Self> {
166        if let Some(content_encoding) = res
167            .headers()
168            .get(CONTENT_ENCODING)
169            .and_then(|v| v.to_str().ok())
170        {
171            match content_encoding {
172                "gzip" | "x-gzip" => Some(Encoding::Gzip),
173                "deflate" => Some(Encoding::Gzip),
174                "identitiy" => Some(Encoding::Identity),
175                _ => None,
176            }
177        } else {
178            Some(Encoding::Identity)
179        }
180    }
181}
182
183/// Creates a boxed body from a Byte like type
184pub fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
185    Full::new(chunk.into())
186        .map_err(|never| match never {})
187        .boxed()
188}
189
190impl FromStr for Encoding {
191    type Err = HyperContentEncodingError;
192    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
193        match value {
194            "x-gzip" | "gzip" => Ok(Encoding::Gzip),
195            "deflate" => Ok(Encoding::Deflate),
196            "identity" => Ok(Encoding::Identity),
197            _ => Err(format!("Unrecognized encoding {}", value).into()),
198        }
199    }
200}
201
202/// Compresses a response with the desired compression algorithm.
203///
204/// Currently, only `gzip` and `deflate` are supported
205///
206/// This method will modify the `Content-Encoding` and `Content-Length` headers
207///
208/// TODO: encode the stream
209pub async fn encode_response(res: Res, content_encoding: Encoding) -> Result<Res> {
210    let headers = res.headers().clone();
211    let status = res.status();
212
213    let res = res.map(|b| b.boxed());
214
215    let body: Bytes = res
216        .into_body()
217        .collect()
218        .await
219        .map_err(convert_err)?
220        .to_bytes();
221
222    let mut ret_vec: Vec<u8> = Vec::new();
223    match content_encoding {
224        Encoding::Gzip => {
225            let body: &[u8] = &body;
226            GzEncoder::new(body, Compression::fast())
227                .read_to_end(&mut ret_vec)
228                .map_err(convert_err)
229        }
230
231        Encoding::Deflate => {
232            let body: &[u8] = &body;
233            DeflateEncoder::new(body, Compression::fast())
234                .read_to_end(&mut ret_vec)
235                .map_err(convert_err)
236        }
237
238        Encoding::Identity => {
239            ret_vec = body.into();
240            Ok(ret_vec.len())
241        }
242    }?;
243
244    let body: Bytes = ret_vec.into();
245    let body_len = body.len();
246
247    let mut res = Response::new(full(body));
248    *res.headers_mut() = headers;
249    *res.status_mut() = status;
250
251    res.headers_mut()
252        .insert(CONTENT_ENCODING, content_encoding.as_header_value());
253
254    res.headers_mut().insert(
255        CONTENT_LENGTH,
256        body_len
257            .to_string()
258            .parse()
259            .expect("Unexpected Content-Length"),
260    );
261
262    Ok(res)
263}
264
/// A hyper service that compresses responses.
///
/// The encoding is negotiated per request from its `Accept-Encoding` header
/// (`gzip`, `deflate` or `identity`). Requests without that header get
/// `gzip`. Requests whose header accepts none of the supported encodings
/// receive `406 Not Acceptable`.
#[derive(Debug, Clone)]
pub struct Compressor<S> {
    inner: S,
}

impl<S> Compressor<S> {
    /// Wraps `inner`. Responses it produces are passed through
    /// [`encode_response`] with the encoding negotiated from each request's
    /// `Accept-Encoding` header.
    pub fn new(inner: S) -> Self {
        Compressor { inner }
    }
}
277
/// Parses a quality value ("weight") as used in `Accept-Encoding` entries.
/// https://datatracker.ietf.org/doc/html/rfc9110#quality.values
///
///  weight = OWS ";" OWS "q=" qvalue
///  qvalue = ( "0" [ "." 0*3DIGIT ] )
///         / ( "1" [ "." 0*3("0") ] )
///
/// `input` is the text following a coding name, e.g. `"; q=0.5, gzip"`.
/// The qvalue is the leading run of digits and `.`; anything after it is
/// ignored. `q=` is matched case-insensitively, because ABNF string literals
/// are case-insensitive (RFC 5234).
///
/// Returns `None` when `input` does not start with a weight, or when the
/// qvalue is not a number in `0.0..=1.0`.
fn parse_weight(input: &str) -> Option<f32> {
    // OWS ";" OWS
    let rest = input.trim_start().strip_prefix(';')?.trim_start();
    // "q="
    let rest = rest.strip_prefix("q=").or_else(|| rest.strip_prefix("Q="))?;
    // qvalue: the leading run of digits and dots.
    let end = rest
        .find(|c: char| !c.is_ascii_digit() && c != '.')
        .unwrap_or(rest.len());
    let qvalue: f32 = rest[..end].parse().ok()?;

    // Both bounds are inclusive: "q=0" and "q=1" are valid weights.
    if (0.0..=1.0).contains(&qvalue) {
        Some(qvalue)
    } else {
        None
    }
}
334
335// Parses the prefered_encoding (only keeping the currently supported encodings)
336fn parse_encoding(accepted_encodings: &str) -> Vec<(Encoding, f32)> {
337    let mut accepted_encodings = accepted_encodings.trim();
338
339    let mut res = Vec::new();
340
341    let mut default_weight: Option<f32> = None;
342
343    loop {
344        for token in ["gzip", "deflate", "identity", "*"] {
345            if accepted_encodings.starts_with(token) {
346                (_, accepted_encodings) = accepted_encodings.split_at(token.len());
347
348                let mut weight: f32 = 1.0;
349                if let Some(res) = parse_weight(accepted_encodings) {
350                    weight = res;
351                }
352
353                if token == "*" {
354                    default_weight = Some(weight);
355                } else {
356                    res.push((Encoding::from_str(token).unwrap(), weight));
357                }
358
359                break;
360            }
361        }
362
363        if let Some(index) = accepted_encodings.find(',') {
364            (_, accepted_encodings) = accepted_encodings.split_at(index);
365            let mut chars = accepted_encodings.chars();
366            chars.next();
367            accepted_encodings = chars.as_str().trim();
368        } else {
369            break;
370        }
371    }
372
373    if let Some(weigth) = default_weight {
374        for encoding in [Encoding::Gzip, Encoding::Deflate, Encoding::Identity] {
375            if !res.iter().any(|(x, _)| *x == encoding) {
376                res.push((encoding, weigth));
377            }
378        }
379    } else if !res.iter().any(|(x, _)| *x == Encoding::Identity) {
380        res.push((Encoding::Identity, 1.0));
381    }
382
383    res
384
385    // Else fail with 415
386}
387
388fn prefered_encoding(accepted_encodings: &str) -> Option<Encoding> {
389    let mut encodings = parse_encoding(accepted_encodings);
390    encodings.sort_by_key(|&(_, w)| -(w * 1000.0) as i32);
391
392    encodings
393        .iter()
394        .find(|(_, w)| *w > 0.0)
395        .map(|(e, _)| e.to_owned())
396}
397
398type Req = Request<Incoming>;
399type Res = Response<BoxBody<Bytes, hyper::Error>>;
400
401impl<S> Service<Req> for Compressor<S>
402where
403    S: Service<Req, Response = Res>,
404    S::Future: 'static + Send,
405    S::Error: 'static + Send,
406    S::Error: Debug,
407    S::Response: 'static,
408{
409    type Response = Res;
410    type Error = Box<HyperContentEncodingError>;
411    type Future =
412        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
413
414    fn call(&self, req: Req) -> Self::Future {
415        let headers = req.headers().clone();
416
417        // Gets the desired encoding
418        let encoding = if let Some(accepted_encodings) = headers.get(ACCEPT_ENCODING) {
419            if let Some(desired_encoding) =
420                accepted_encodings.to_str().ok().and_then(prefered_encoding)
421            {
422                desired_encoding
423            } else {
424                return Box::pin(async move {
425                    let mut res = Response::new(full("Unsuported requestedd encoding\n"));
426                    *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
427                    Ok(res)
428                });
429            }
430        } else {
431            Encoding::Gzip
432        };
433
434        let fut = self.inner.call(req);
435
436        let f = async move {
437            match fut.await {
438                Ok(response) => encode_response(response, encoding).await.map_err(Box::new),
439                Err(e) => Err(Box::new(convert_err(e))),
440            }
441        };
442
443        Box::pin(f)
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_weight() {
        let result = parse_encoding("compress, gzip");
        assert_eq!(
            vec![(Encoding::Gzip, 1.0), (Encoding::Identity, 1.0)],
            result
        );
    }

    #[test]
    fn empty() {
        let result = parse_encoding("");
        assert_eq!(vec![(Encoding::Identity, 1.0)], result);
    }

    #[test]
    fn star() {
        let result = parse_encoding("*");
        assert_eq!(
            vec![
                (Encoding::Gzip, 1.0),
                (Encoding::Deflate, 1.0),
                (Encoding::Identity, 1.0)
            ],
            result
        );
    }

    #[test]
    fn weight() {
        let result = parse_encoding("deflate;q=0.5, gzip;q=1.0");
        assert_eq!(
            vec![
                (Encoding::Deflate, 0.5),
                (Encoding::Gzip, 1.0),
                (Encoding::Identity, 1.0)
            ],
            result
        );
    }

    #[test]
    fn no_identity() {
        let result = parse_encoding("gzip;q=1.0, deflate; q=0.5, *;q=0");
        assert_eq!(
            vec![
                (Encoding::Gzip, 1.0),
                (Encoding::Deflate, 0.5),
                (Encoding::Identity, 0.0)
            ],
            result
        );
    }

    #[test]
    fn unknown_codings_are_ignored() {
        let result = parse_encoding("br, compress");
        assert_eq!(vec![(Encoding::Identity, 1.0)], result);
    }

    #[test]
    fn prefers_highest_weight() {
        assert_eq!(
            Some(Encoding::Gzip),
            prefered_encoding("deflate;q=0.5, gzip;q=0.8, identity;q=0.1")
        );
    }

    #[test]
    fn everything_refused() {
        assert_eq!(None, prefered_encoding("gzip;q=0, *;q=0"));
    }

    #[test]
    fn encoding_from_str() {
        assert_eq!(Encoding::Gzip, "gzip".parse::<Encoding>().unwrap());
        assert_eq!(Encoding::Gzip, "x-gzip".parse::<Encoding>().unwrap());
        assert_eq!(Encoding::Deflate, "deflate".parse::<Encoding>().unwrap());
        assert_eq!(Encoding::Identity, "identity".parse::<Encoding>().unwrap());
        assert!("br".parse::<Encoding>().is_err());
    }
}