micro_web/encoding/
encoder.rs

1use crate::encoding::Writer;
2use crate::handler::RequestHandler;
3use crate::handler::handler_decorator::HandlerDecorator;
4use crate::handler::handler_decorator_factory::HandlerDecoratorFactory;
5use crate::{OptionReqBody, RequestContext, ResponseBody};
6use async_trait::async_trait;
7use bytes::{Buf, Bytes};
8use flate2::Compression;
9use flate2::write::{GzEncoder, ZlibEncoder};
10use http::{Response, StatusCode};
11use http_body::{Body, Frame};
12use http_body_util::combinators::UnsyncBoxBody;
13use micro_http::protocol::{HttpError, SendError};
14use pin_project_lite::pin_project;
15use std::fmt::Debug;
16use std::io;
17use std::io::Write;
18use std::pin::Pin;
19use std::task::{Context, Poll, ready};
20use tracing::{error, trace};
21use zstd::stream::write::Encoder as ZstdEncoder;
// Largely adapted (with thanks) from actix-http: https://github.com/actix/actix-web/blob/master/actix-http/src/encoding/encoder.rs
23
24/// Represents different types of content encoding.
25pub(crate) enum Encoder {
26    /// Gzip encoding.
27    Gzip(GzEncoder<Writer>),
28    /// Deflate encoding.
29    Deflate(ZlibEncoder<Writer>),
30    /// Zstd encoding.
31    Zstd(ZstdEncoder<'static, Writer>),
32    /// Brotli encoding.
33    Br(Box<brotli::CompressorWriter<Writer>>),
34}
35
36impl Encoder {
37    /// Creates a new Gzip encoder.
38    fn gzip() -> Self {
39        Self::Gzip(GzEncoder::new(Writer::new(), Compression::best()))
40    }
41
42    /// Creates a new Deflate encoder.
43    fn deflate() -> Self {
44        Self::Deflate(ZlibEncoder::new(Writer::new(), Compression::best()))
45    }
46
47    /// Creates a new Zstd encoder.
48    fn zstd() -> Self {
49        // todo: remove the unwrap
50        Self::Zstd(ZstdEncoder::new(Writer::new(), 6).unwrap())
51    }
52
53    /// Creates a new Brotli encoder.
54    fn br() -> Self {
55        Self::Br(Box::new(brotli::CompressorWriter::new(
56            Writer::new(),
57            32 * 1024, // 32 KiB buffer
58            3,         // BROTLI_PARAM_QUALITY
59            22,        // BROTLI_PARAM_LGWIN
60        )))
61    }
62
63    /// Selects an encoder based on the `Accept-Encoding` header.
64    fn select(accept_encodings: &str) -> Option<Self> {
65        if accept_encodings.contains("zstd") {
66            Some(Self::zstd())
67        } else if accept_encodings.contains("br") {
68            Some(Self::br())
69        } else if accept_encodings.contains("gzip") {
70            Some(Self::gzip())
71        } else if accept_encodings.contains("deflate") {
72            Some(Self::deflate())
73        } else {
74            None
75        }
76    }
77
78    /// Returns the name of the encoding.
79    fn name(&self) -> &'static str {
80        match self {
81            Encoder::Gzip(_) => "gzip",
82            Encoder::Deflate(_) => "deflate",
83            Encoder::Zstd(_) => "zstd",
84            Encoder::Br(_) => "br",
85        }
86    }
87
88    /// Writes data to the encoder.
89    fn write(&mut self, data: &[u8]) -> Result<(), io::Error> {
90        match self {
91            Self::Gzip(encoder) => match encoder.write_all(data) {
92                Ok(_) => Ok(()),
93                Err(err) => {
94                    trace!("Error encoding gzip encoding: {}", err);
95                    Err(err)
96                }
97            },
98
99            Self::Deflate(encoder) => match encoder.write_all(data) {
100                Ok(_) => Ok(()),
101                Err(err) => {
102                    trace!("Error encoding deflate encoding: {}", err);
103                    Err(err)
104                }
105            },
106
107            Self::Zstd(encoder) => match encoder.write_all(data) {
108                Ok(_) => Ok(()),
109                Err(err) => {
110                    trace!("Error encoding zstd encoding: {}", err);
111                    Err(err)
112                }
113            },
114
115            Self::Br(encoder) => match encoder.write_all(data) {
116                Ok(_) => Ok(()),
117                Err(err) => {
118                    trace!("Error encoding br encoding: {}", err);
119                    Err(err)
120                }
121            },
122        }
123    }
124
125    /// Takes the encoded data from the encoder.
126    fn take(&mut self) -> Bytes {
127        match self {
128            Self::Gzip(encoder) => encoder.get_mut().take(),
129            Self::Deflate(encoder) => encoder.get_mut().take(),
130            Self::Zstd(encoder) => encoder.get_mut().take(),
131            Self::Br(encoder) => encoder.get_mut().take(),
132        }
133    }
134
135    /// Finishes the encoding process and returns the encoded data.
136    fn finish(self) -> Result<Bytes, io::Error> {
137        match self {
138            Self::Gzip(encoder) => match encoder.finish() {
139                Ok(writer) => Ok(writer.buf.freeze()),
140                Err(err) => Err(err),
141            },
142
143            Self::Deflate(encoder) => match encoder.finish() {
144                Ok(writer) => Ok(writer.buf.freeze()),
145                Err(err) => Err(err),
146            },
147
148            Self::Zstd(encoder) => match encoder.finish() {
149                Ok(writer) => Ok(writer.buf.freeze()),
150                Err(err) => Err(err),
151            },
152
153            Self::Br(mut encoder) => match encoder.flush() {
154                Ok(()) => Ok(encoder.into_inner().buf.freeze()),
155                Err(err) => Err(err),
156            },
157        }
158    }
159}
160
161pin_project! {
162    /// A wrapper around a `Body` that encodes the data.
163    struct EncodedBody<B: Body> {
164        #[pin]
165        inner: B,
166        encoder: Option<Encoder>,
167        state: Option<bool>,
168    }
169}
170
171impl<B: Body> EncodedBody<B> {
172    /// Creates a new `EncodedBody`.
173    fn new(b: B, encoder: Encoder) -> Self {
174        Self { inner: b, encoder: Some(encoder), state: Some(true) }
175    }
176}
177
178impl<B> Body for EncodedBody<B>
179where
180    B: Body + Unpin,
181    B::Data: Buf + Debug,
182    B::Error: ToString,
183{
184    type Data = Bytes;
185    type Error = HttpError;
186
187    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
188        let mut this = self.project();
189
190        if this.state.is_none() {
191            return Poll::Ready(None);
192        }
193
194        loop {
195            return match ready!(this.inner.as_mut().poll_frame(cx)) {
196                Some(Ok(frame)) => {
197                    let data = match frame.into_data() {
198                        Ok(data) => data,
199                        Err(mut frame) => {
200                            let debug_info = frame.trailers_mut();
201                            error!("want to data from body, but receive trailer header: {:?}", debug_info);
202                            return Poll::Ready(Some(
203                                Err(SendError::invalid_body(format!("invalid body frame : {:?}", debug_info)).into()),
204                            ));
205                        }
206                    };
207
208                    match this.encoder.as_mut().unwrap().write(data.chunk()) {
209                        Ok(_) => (),
210                        Err(e) => {
211                            return Poll::Ready(Some(Err(SendError::from(e).into())));
212                        }
213                    }
214                    // use wrap here is safe, because we only take it when receive None
215                    let bytes = this.encoder.as_mut().unwrap().take();
216                    if bytes.is_empty() {
217                        continue;
218                    }
219                    Poll::Ready(Some(Ok(Frame::data(bytes))))
220                }
221                Some(Err(e)) => Poll::Ready(Some(Err(SendError::invalid_body(e.to_string()).into()))),
222                None => {
223                    if this.state.is_some() {
224                        // will only run below  code once
225                        this.state.take();
226
227                        // unwrap here is safe, because we only take once
228                        let bytes = match this.encoder.take().unwrap().finish() {
229                            Ok(bytes) => bytes,
230                            Err(e) => {
231                                return Poll::Ready(Some(Err(SendError::from(e).into())));
232                            }
233                        };
234                        if !bytes.is_empty() { Poll::Ready(Some(Ok(Frame::data(bytes)))) } else { Poll::Ready(None) }
235                    } else {
236                        Poll::Ready(None)
237                    }
238                }
239            };
240        }
241    }
242
243    fn is_end_stream(&self) -> bool {
244        self.inner.is_end_stream()
245    }
246}
247
248/// A request handler that encodes the response body.
249#[derive(Debug)]
250pub struct EncodeRequestHandler<H: RequestHandler> {
251    handler: H,
252}
253
254/// A wrapper that creates `EncodeRequestHandler`.
255#[derive(Debug)]
256pub struct EncodeDecorator;
257
258impl<H: RequestHandler> HandlerDecorator<H> for EncodeDecorator {
259    type Output = EncodeRequestHandler<H>;
260
261    fn decorate(&self, raw: H) -> Self::Output {
262        EncodeRequestHandler { handler: raw }
263    }
264}
265
266impl HandlerDecoratorFactory for EncodeDecorator {
267    type Output<In>
268        = EncodeDecorator
269    where
270        In: RequestHandler;
271
272    fn create_decorator<In>(&self) -> Self::Output<In>
273    where
274        In: RequestHandler,
275    {
276        EncodeDecorator
277    }
278}
279
280#[async_trait]
281impl<H: RequestHandler> RequestHandler for EncodeRequestHandler<H> {
282    async fn invoke<'server, 'req>(&self, req: &mut RequestContext<'server, 'req>, req_body: OptionReqBody) -> Response<ResponseBody> {
283        let mut resp = self.handler.invoke(req, req_body).await;
284        encode(req, &mut resp);
285        resp
286    }
287}
288
289/// Encodes the response body based on the `Accept-Encoding` header.
290fn encode(req: &RequestContext, resp: &mut Response<ResponseBody>) {
291    let status_code = resp.status();
292    if status_code == StatusCode::NO_CONTENT || status_code == StatusCode::SWITCHING_PROTOCOLS {
293        return;
294    }
295
296    // response has already encoded
297    if req.headers().contains_key(http::header::CONTENT_ENCODING) {
298        return;
299    }
300
301    // request doesn't have any accept encodings
302    let possible_encodings = req.headers().get(http::header::ACCEPT_ENCODING);
303    if possible_encodings.is_none() {
304        return;
305    }
306
307    // here using unwrap is safe because we have checked
308    let accept_encodings = match possible_encodings.unwrap().to_str() {
309        Ok(s) => s,
310        Err(_) => {
311            return;
312        }
313    };
314
315    let encoder = match Encoder::select(accept_encodings) {
316        Some(encoder) => encoder,
317        None => {
318            return;
319        }
320    };
321
322    let body = resp.body_mut();
323
324    if body.is_empty() {
325        return;
326    }
327
328    match body.size_hint().upper() {
329        Some(upper) if upper <= 1024 => {
330            // less than 1k, we needn't compress
331            return;
332        }
333        _ => (),
334    }
335
336    let encoder_name = encoder.name();
337    let encoded_body = EncodedBody::new(body.take(), encoder);
338    body.replace(ResponseBody::stream(UnsyncBoxBody::new(encoded_body)));
339
340    resp.headers_mut().remove(http::header::CONTENT_LENGTH);
341    resp.headers_mut().append(http::header::CONTENT_ENCODING, encoder_name.parse().unwrap());
342}