1use crate::decorator::Decorator;
2use crate::encoding::Writer;
3use crate::handler::RequestHandler;
4use crate::{OptionReqBody, RequestContext, ResponseBody};
5use async_trait::async_trait;
6use bytes::{Buf, Bytes};
7use flate2::write::{GzEncoder, ZlibEncoder};
8use flate2::Compression;
9use http::{Response, StatusCode};
10use http_body::{Body, Frame};
11use http_body_util::combinators::UnsyncBoxBody;
12use micro_http::protocol::{HttpError, SendError};
13use pin_project_lite::pin_project;
14use std::fmt::Debug;
15use std::io;
16use std::io::Write;
17use std::pin::Pin;
18use std::task::{ready, Context, Poll};
19use tracing::{error, trace};
20use zstd::stream::write::Encoder as ZstdEncoder;
21pub(crate) enum Encoder {
25 Gzip(GzEncoder<Writer>),
27 Deflate(ZlibEncoder<Writer>),
29 Zstd(ZstdEncoder<'static, Writer>),
31 Br(Box<brotli::CompressorWriter<Writer>>),
33}
34
35impl Encoder {
36 fn gzip() -> Self {
38 Self::Gzip(GzEncoder::new(Writer::new(), Compression::best()))
39 }
40
41 fn deflate() -> Self {
43 Self::Deflate(ZlibEncoder::new(Writer::new(), Compression::best()))
44 }
45
46 fn zstd() -> Self {
48 Self::Zstd(ZstdEncoder::new(Writer::new(), 6).unwrap())
50 }
51
52 fn br() -> Self {
54 Self::Br(Box::new(brotli::CompressorWriter::new(
55 Writer::new(),
56 32 * 1024, 3, 22, )))
60 }
61
62 fn select(accept_encodings: &str) -> Option<Self> {
64 if accept_encodings.contains("zstd") {
65 Some(Self::zstd())
66 } else if accept_encodings.contains("br") {
67 Some(Self::br())
68 } else if accept_encodings.contains("gzip") {
69 Some(Self::gzip())
70 } else if accept_encodings.contains("deflate") {
71 Some(Self::deflate())
72 } else {
73 None
74 }
75 }
76
77 fn name(&self) -> &'static str {
79 match self {
80 Encoder::Gzip(_) => "gzip",
81 Encoder::Deflate(_) => "deflate",
82 Encoder::Zstd(_) => "zstd",
83 Encoder::Br(_) => "br",
84 }
85 }
86
87 fn write(&mut self, data: &[u8]) -> Result<(), io::Error> {
89 match self {
90 Self::Gzip(ref mut encoder) => match encoder.write_all(data) {
91 Ok(_) => Ok(()),
92 Err(err) => {
93 trace!("Error encoding gzip encoding: {}", err);
94 Err(err)
95 }
96 },
97
98 Self::Deflate(ref mut encoder) => match encoder.write_all(data) {
99 Ok(_) => Ok(()),
100 Err(err) => {
101 trace!("Error encoding deflate encoding: {}", err);
102 Err(err)
103 }
104 },
105
106 Self::Zstd(ref mut encoder) => match encoder.write_all(data) {
107 Ok(_) => Ok(()),
108 Err(err) => {
109 trace!("Error encoding zstd encoding: {}", err);
110 Err(err)
111 }
112 },
113
114 Self::Br(ref mut encoder) => match encoder.write_all(data) {
115 Ok(_) => Ok(()),
116 Err(err) => {
117 trace!("Error encoding br encoding: {}", err);
118 Err(err)
119 }
120 },
121 }
122 }
123
124 fn take(&mut self) -> Bytes {
126 match *self {
127 Self::Gzip(ref mut encoder) => encoder.get_mut().take(),
128 Self::Deflate(ref mut encoder) => encoder.get_mut().take(),
129 Self::Zstd(ref mut encoder) => encoder.get_mut().take(),
130 Self::Br(ref mut encoder) => encoder.get_mut().take(),
131 }
132 }
133
134 fn finish(self) -> Result<Bytes, io::Error> {
136 match self {
137 Self::Gzip(encoder) => match encoder.finish() {
138 Ok(writer) => Ok(writer.buf.freeze()),
139 Err(err) => Err(err),
140 },
141
142 Self::Deflate(encoder) => match encoder.finish() {
143 Ok(writer) => Ok(writer.buf.freeze()),
144 Err(err) => Err(err),
145 },
146
147 Self::Zstd(encoder) => match encoder.finish() {
148 Ok(writer) => Ok(writer.buf.freeze()),
149 Err(err) => Err(err),
150 },
151
152 Self::Br(mut encoder) => match encoder.flush() {
153 Ok(()) => Ok(encoder.into_inner().buf.freeze()),
154 Err(err) => Err(err),
155 },
156 }
157 }
158}
159
160pin_project! {
161 struct EncodedBody<B: Body> {
163 #[pin]
164 inner: B,
165 encoder: Option<Encoder>,
166 state: Option<bool>,
167 }
168}
169
170impl<B: Body> EncodedBody<B> {
171 fn new(b: B, encoder: Encoder) -> Self {
173 Self { inner: b, encoder: Some(encoder), state: Some(true) }
174 }
175}
176
177impl<B> Body for EncodedBody<B>
178where
179 B: Body + Unpin,
180 B::Data: Buf + Debug,
181 B::Error: ToString,
182{
183 type Data = Bytes;
184 type Error = HttpError;
185
186 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
187 let mut this = self.project();
188
189 if this.state.is_none() {
190 return Poll::Ready(None);
191 }
192
193 loop {
194 return match ready!(this.inner.as_mut().poll_frame(cx)) {
195 Some(Ok(frame)) => {
196 let data = match frame.into_data() {
197 Ok(data) => data,
198 Err(mut frame) => {
199 let debug_info = frame.trailers_mut();
200 error!("want to data from body, but receive trailer header: {:?}", debug_info);
201 return Poll::Ready(Some(
202 Err(SendError::invalid_body(format!("invalid body frame : {:?}", debug_info)).into()),
203 ));
204 }
205 };
206
207 match this.encoder.as_mut().unwrap().write(data.chunk()) {
208 Ok(_) => (),
209 Err(e) => {
210 return Poll::Ready(Some(Err(SendError::from(e).into())));
211 }
212 }
213 let bytes = this.encoder.as_mut().unwrap().take();
215 if bytes.is_empty() {
216 continue;
217 }
218 Poll::Ready(Some(Ok(Frame::data(bytes))))
219 }
220 Some(Err(e)) => Poll::Ready(Some(Err(SendError::invalid_body(e.to_string()).into()))),
221 None => {
222 if this.state.is_some() {
223 this.state.take();
225
226 let bytes = match this.encoder.take().unwrap().finish() {
228 Ok(bytes) => bytes,
229 Err(e) => {
230 return Poll::Ready(Some(Err(SendError::from(e).into())));
231 }
232 };
233 if !bytes.is_empty() {
234 Poll::Ready(Some(Ok(Frame::data(bytes))))
235 } else {
236 Poll::Ready(None)
237 }
238 } else {
239 Poll::Ready(None)
240 }
241 }
242 };
243 }
244 }
245
246 fn is_end_stream(&self) -> bool {
247 self.inner.is_end_stream()
248 }
249}
250
251pub struct EncodeRequestHandler<H: RequestHandler> {
253 handler: H,
254}
255
/// Decorator that adds response compression to a request handler by wrapping
/// it in an [`EncodeRequestHandler`].
pub struct EncodeDecorator;
258
259impl<H: RequestHandler> Decorator<H> for EncodeDecorator {
260 type Out = EncodeRequestHandler<H>;
261
262 fn decorate(&self, raw: H) -> Self::Out {
263 EncodeRequestHandler { handler: raw }
264 }
265}
266
267#[async_trait]
268impl<H: RequestHandler> RequestHandler for EncodeRequestHandler<H> {
269 async fn invoke<'server, 'req>(&self, req: &mut RequestContext<'server, 'req>, req_body: OptionReqBody) -> Response<ResponseBody> {
270 let mut resp = self.handler.invoke(req, req_body).await;
271 encode(req, &mut resp);
272 resp
273 }
274}
275
276fn encode(req: &RequestContext, resp: &mut Response<ResponseBody>) {
278 let status_code = resp.status();
279 if status_code == StatusCode::NO_CONTENT || status_code == StatusCode::SWITCHING_PROTOCOLS {
280 return;
281 }
282
283 if req.headers().contains_key(http::header::CONTENT_ENCODING) {
285 return;
286 }
287
288 let possible_encodings = req.headers().get(http::header::ACCEPT_ENCODING);
290 if possible_encodings.is_none() {
291 return;
292 }
293
294 let accept_encodings = match possible_encodings.unwrap().to_str() {
296 Ok(s) => s,
297 Err(_) => {
298 return;
299 }
300 };
301
302 let encoder = match Encoder::select(accept_encodings) {
303 Some(encoder) => encoder,
304 None => {
305 return;
306 }
307 };
308
309 let body = resp.body_mut();
310
311 if body.is_empty() {
312 return;
313 }
314
315 match body.size_hint().upper() {
316 Some(upper) if upper <= 1024 => {
317 return;
319 }
320 _ => (),
321 }
322
323 let encoder_name = encoder.name();
324 let encoded_body = EncodedBody::new(body.take(), encoder);
325 body.replace(ResponseBody::stream(UnsyncBoxBody::new(encoded_body)));
326
327 resp.headers_mut().remove(http::header::CONTENT_LENGTH);
328 resp.headers_mut().append(http::header::CONTENT_ENCODING, encoder_name.parse().unwrap());
329}