micro_web/encoding/
encoder.rs

1use crate::decorator::Decorator;
2use crate::encoding::Writer;
3use crate::handler::RequestHandler;
4use crate::{OptionReqBody, RequestContext, ResponseBody};
5use async_trait::async_trait;
6use bytes::{Buf, Bytes};
7use flate2::write::{GzEncoder, ZlibEncoder};
8use flate2::Compression;
9use http::{Response, StatusCode};
10use http_body::{Body, Frame};
11use http_body_util::combinators::UnsyncBoxBody;
12use micro_http::protocol::{HttpError, SendError};
13use pin_project_lite::pin_project;
14use std::fmt::Debug;
15use std::io;
16use std::io::Write;
17use std::pin::Pin;
18use std::task::{ready, Context, Poll};
19use tracing::{error, trace};
20use zstd::stream::write::Encoder as ZstdEncoder;
// Largely adapted, with thanks, from actix-http: https://github.com/actix/actix-web/blob/master/actix-http/src/encoding/encoder.rs
22
23/// Represents different types of content encoding.
24pub(crate) enum Encoder {
25    /// Gzip encoding.
26    Gzip(GzEncoder<Writer>),
27    /// Deflate encoding.
28    Deflate(ZlibEncoder<Writer>),
29    /// Zstd encoding.
30    Zstd(ZstdEncoder<'static, Writer>),
31    /// Brotli encoding.
32    Br(Box<brotli::CompressorWriter<Writer>>),
33}
34
35impl Encoder {
36    /// Creates a new Gzip encoder.
37    fn gzip() -> Self {
38        Self::Gzip(GzEncoder::new(Writer::new(), Compression::best()))
39    }
40
41    /// Creates a new Deflate encoder.
42    fn deflate() -> Self {
43        Self::Deflate(ZlibEncoder::new(Writer::new(), Compression::best()))
44    }
45
46    /// Creates a new Zstd encoder.
47    fn zstd() -> Self {
48        // todo: remove the unwrap
49        Self::Zstd(ZstdEncoder::new(Writer::new(), 6).unwrap())
50    }
51
52    /// Creates a new Brotli encoder.
53    fn br() -> Self {
54        Self::Br(Box::new(brotli::CompressorWriter::new(
55            Writer::new(),
56            32 * 1024, // 32 KiB buffer
57            3,         // BROTLI_PARAM_QUALITY
58            22,        // BROTLI_PARAM_LGWIN
59        )))
60    }
61
62    /// Selects an encoder based on the `Accept-Encoding` header.
63    fn select(accept_encodings: &str) -> Option<Self> {
64        if accept_encodings.contains("zstd") {
65            Some(Self::zstd())
66        } else if accept_encodings.contains("br") {
67            Some(Self::br())
68        } else if accept_encodings.contains("gzip") {
69            Some(Self::gzip())
70        } else if accept_encodings.contains("deflate") {
71            Some(Self::deflate())
72        } else {
73            None
74        }
75    }
76
77    /// Returns the name of the encoding.
78    fn name(&self) -> &'static str {
79        match self {
80            Encoder::Gzip(_) => "gzip",
81            Encoder::Deflate(_) => "deflate",
82            Encoder::Zstd(_) => "zstd",
83            Encoder::Br(_) => "br",
84        }
85    }
86
87    /// Writes data to the encoder.
88    fn write(&mut self, data: &[u8]) -> Result<(), io::Error> {
89        match self {
90            Self::Gzip(ref mut encoder) => match encoder.write_all(data) {
91                Ok(_) => Ok(()),
92                Err(err) => {
93                    trace!("Error encoding gzip encoding: {}", err);
94                    Err(err)
95                }
96            },
97
98            Self::Deflate(ref mut encoder) => match encoder.write_all(data) {
99                Ok(_) => Ok(()),
100                Err(err) => {
101                    trace!("Error encoding deflate encoding: {}", err);
102                    Err(err)
103                }
104            },
105
106            Self::Zstd(ref mut encoder) => match encoder.write_all(data) {
107                Ok(_) => Ok(()),
108                Err(err) => {
109                    trace!("Error encoding zstd encoding: {}", err);
110                    Err(err)
111                }
112            },
113
114            Self::Br(ref mut encoder) => match encoder.write_all(data) {
115                Ok(_) => Ok(()),
116                Err(err) => {
117                    trace!("Error encoding br encoding: {}", err);
118                    Err(err)
119                }
120            },
121        }
122    }
123
124    /// Takes the encoded data from the encoder.
125    fn take(&mut self) -> Bytes {
126        match *self {
127            Self::Gzip(ref mut encoder) => encoder.get_mut().take(),
128            Self::Deflate(ref mut encoder) => encoder.get_mut().take(),
129            Self::Zstd(ref mut encoder) => encoder.get_mut().take(),
130            Self::Br(ref mut encoder) => encoder.get_mut().take(),
131        }
132    }
133
134    /// Finishes the encoding process and returns the encoded data.
135    fn finish(self) -> Result<Bytes, io::Error> {
136        match self {
137            Self::Gzip(encoder) => match encoder.finish() {
138                Ok(writer) => Ok(writer.buf.freeze()),
139                Err(err) => Err(err),
140            },
141
142            Self::Deflate(encoder) => match encoder.finish() {
143                Ok(writer) => Ok(writer.buf.freeze()),
144                Err(err) => Err(err),
145            },
146
147            Self::Zstd(encoder) => match encoder.finish() {
148                Ok(writer) => Ok(writer.buf.freeze()),
149                Err(err) => Err(err),
150            },
151
152            Self::Br(mut encoder) => match encoder.flush() {
153                Ok(()) => Ok(encoder.into_inner().buf.freeze()),
154                Err(err) => Err(err),
155            },
156        }
157    }
158}
159
160pin_project! {
161    /// A wrapper around a `Body` that encodes the data.
162    struct EncodedBody<B: Body> {
163        #[pin]
164        inner: B,
165        encoder: Option<Encoder>,
166        state: Option<bool>,
167    }
168}
169
170impl<B: Body> EncodedBody<B> {
171    /// Creates a new `EncodedBody`.
172    fn new(b: B, encoder: Encoder) -> Self {
173        Self { inner: b, encoder: Some(encoder), state: Some(true) }
174    }
175}
176
177impl<B> Body for EncodedBody<B>
178where
179    B: Body + Unpin,
180    B::Data: Buf + Debug,
181    B::Error: ToString,
182{
183    type Data = Bytes;
184    type Error = HttpError;
185
186    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
187        let mut this = self.project();
188
189        if this.state.is_none() {
190            return Poll::Ready(None);
191        }
192
193        loop {
194            return match ready!(this.inner.as_mut().poll_frame(cx)) {
195                Some(Ok(frame)) => {
196                    let data = match frame.into_data() {
197                        Ok(data) => data,
198                        Err(mut frame) => {
199                            let debug_info = frame.trailers_mut();
200                            error!("want to data from body, but receive trailer header: {:?}", debug_info);
201                            return Poll::Ready(Some(
202                                Err(SendError::invalid_body(format!("invalid body frame : {:?}", debug_info)).into()),
203                            ));
204                        }
205                    };
206
207                    match this.encoder.as_mut().unwrap().write(data.chunk()) {
208                        Ok(_) => (),
209                        Err(e) => {
210                            return Poll::Ready(Some(Err(SendError::from(e).into())));
211                        }
212                    }
213                    // use wrap here is safe, because we only take it when receive None
214                    let bytes = this.encoder.as_mut().unwrap().take();
215                    if bytes.is_empty() {
216                        continue;
217                    }
218                    Poll::Ready(Some(Ok(Frame::data(bytes))))
219                }
220                Some(Err(e)) => Poll::Ready(Some(Err(SendError::invalid_body(e.to_string()).into()))),
221                None => {
222                    if this.state.is_some() {
223                        // will only run below  code once
224                        this.state.take();
225
226                        // unwrap here is safe, because we only take once
227                        let bytes = match this.encoder.take().unwrap().finish() {
228                            Ok(bytes) => bytes,
229                            Err(e) => {
230                                return Poll::Ready(Some(Err(SendError::from(e).into())));
231                            }
232                        };
233                        if !bytes.is_empty() {
234                            Poll::Ready(Some(Ok(Frame::data(bytes))))
235                        } else {
236                            Poll::Ready(None)
237                        }
238                    } else {
239                        Poll::Ready(None)
240                    }
241                }
242            };
243        }
244    }
245
246    fn is_end_stream(&self) -> bool {
247        self.inner.is_end_stream()
248    }
249}
250
251/// A request handler that encodes the response body.
252pub struct EncodeRequestHandler<H: RequestHandler> {
253    handler: H,
254}
255
/// A [`Decorator`] that wraps a handler in an [`EncodeRequestHandler`].
pub struct EncodeDecorator;
258
259impl<H: RequestHandler> Decorator<H> for EncodeDecorator {
260    type Out = EncodeRequestHandler<H>;
261
262    fn decorate(&self, raw: H) -> Self::Out {
263        EncodeRequestHandler { handler: raw }
264    }
265}
266
267#[async_trait]
268impl<H: RequestHandler> RequestHandler for EncodeRequestHandler<H> {
269    async fn invoke<'server, 'req>(&self, req: &mut RequestContext<'server, 'req>, req_body: OptionReqBody) -> Response<ResponseBody> {
270        let mut resp = self.handler.invoke(req, req_body).await;
271        encode(req, &mut resp);
272        resp
273    }
274}
275
276/// Encodes the response body based on the `Accept-Encoding` header.
277fn encode(req: &RequestContext, resp: &mut Response<ResponseBody>) {
278    let status_code = resp.status();
279    if status_code == StatusCode::NO_CONTENT || status_code == StatusCode::SWITCHING_PROTOCOLS {
280        return;
281    }
282
283    // response has already encoded
284    if req.headers().contains_key(http::header::CONTENT_ENCODING) {
285        return;
286    }
287
288    // request doesn't have any accept encodings
289    let possible_encodings = req.headers().get(http::header::ACCEPT_ENCODING);
290    if possible_encodings.is_none() {
291        return;
292    }
293
294    // here using unwrap is safe because we has checked
295    let accept_encodings = match possible_encodings.unwrap().to_str() {
296        Ok(s) => s,
297        Err(_) => {
298            return;
299        }
300    };
301
302    let encoder = match Encoder::select(accept_encodings) {
303        Some(encoder) => encoder,
304        None => {
305            return;
306        }
307    };
308
309    let body = resp.body_mut();
310
311    if body.is_empty() {
312        return;
313    }
314
315    match body.size_hint().upper() {
316        Some(upper) if upper <= 1024 => {
317            // less then 1k, we needn't compress
318            return;
319        }
320        _ => (),
321    }
322
323    let encoder_name = encoder.name();
324    let encoded_body = EncodedBody::new(body.take(), encoder);
325    body.replace(ResponseBody::stream(UnsyncBoxBody::new(encoded_body)));
326
327    resp.headers_mut().remove(http::header::CONTENT_LENGTH);
328    resp.headers_mut().append(http::header::CONTENT_ENCODING, encoder_name.parse().unwrap());
329}