1use crate::encoding::Writer;
2use crate::handler::RequestHandler;
3use crate::handler::handler_decorator::HandlerDecorator;
4use crate::handler::handler_decorator_factory::HandlerDecoratorFactory;
5use crate::{OptionReqBody, RequestContext, ResponseBody};
6use async_trait::async_trait;
7use bytes::{Buf, Bytes};
8use flate2::Compression;
9use flate2::write::{GzEncoder, ZlibEncoder};
10use http::{Response, StatusCode};
11use http_body::{Body, Frame};
12use http_body_util::combinators::UnsyncBoxBody;
13use micro_http::protocol::{HttpError, SendError};
14use pin_project_lite::pin_project;
15use std::fmt::Debug;
16use std::io;
17use std::io::Write;
18use std::pin::Pin;
19use std::task::{Context, Poll, ready};
20use tracing::{error, trace};
21use zstd::stream::write::Encoder as ZstdEncoder;
22pub(crate) enum Encoder {
26 Gzip(GzEncoder<Writer>),
28 Deflate(ZlibEncoder<Writer>),
30 Zstd(ZstdEncoder<'static, Writer>),
32 Br(Box<brotli::CompressorWriter<Writer>>),
34}
35
36impl Encoder {
37 fn gzip() -> Self {
39 Self::Gzip(GzEncoder::new(Writer::new(), Compression::best()))
40 }
41
42 fn deflate() -> Self {
44 Self::Deflate(ZlibEncoder::new(Writer::new(), Compression::best()))
45 }
46
47 fn zstd() -> Self {
49 Self::Zstd(ZstdEncoder::new(Writer::new(), 6).unwrap())
51 }
52
53 fn br() -> Self {
55 Self::Br(Box::new(brotli::CompressorWriter::new(
56 Writer::new(),
57 32 * 1024, 3, 22, )))
61 }
62
63 fn select(accept_encodings: &str) -> Option<Self> {
65 if accept_encodings.contains("zstd") {
66 Some(Self::zstd())
67 } else if accept_encodings.contains("br") {
68 Some(Self::br())
69 } else if accept_encodings.contains("gzip") {
70 Some(Self::gzip())
71 } else if accept_encodings.contains("deflate") {
72 Some(Self::deflate())
73 } else {
74 None
75 }
76 }
77
78 fn name(&self) -> &'static str {
80 match self {
81 Encoder::Gzip(_) => "gzip",
82 Encoder::Deflate(_) => "deflate",
83 Encoder::Zstd(_) => "zstd",
84 Encoder::Br(_) => "br",
85 }
86 }
87
88 fn write(&mut self, data: &[u8]) -> Result<(), io::Error> {
90 match self {
91 Self::Gzip(encoder) => match encoder.write_all(data) {
92 Ok(_) => Ok(()),
93 Err(err) => {
94 trace!("Error encoding gzip encoding: {}", err);
95 Err(err)
96 }
97 },
98
99 Self::Deflate(encoder) => match encoder.write_all(data) {
100 Ok(_) => Ok(()),
101 Err(err) => {
102 trace!("Error encoding deflate encoding: {}", err);
103 Err(err)
104 }
105 },
106
107 Self::Zstd(encoder) => match encoder.write_all(data) {
108 Ok(_) => Ok(()),
109 Err(err) => {
110 trace!("Error encoding zstd encoding: {}", err);
111 Err(err)
112 }
113 },
114
115 Self::Br(encoder) => match encoder.write_all(data) {
116 Ok(_) => Ok(()),
117 Err(err) => {
118 trace!("Error encoding br encoding: {}", err);
119 Err(err)
120 }
121 },
122 }
123 }
124
125 fn take(&mut self) -> Bytes {
127 match self {
128 Self::Gzip(encoder) => encoder.get_mut().take(),
129 Self::Deflate(encoder) => encoder.get_mut().take(),
130 Self::Zstd(encoder) => encoder.get_mut().take(),
131 Self::Br(encoder) => encoder.get_mut().take(),
132 }
133 }
134
135 fn finish(self) -> Result<Bytes, io::Error> {
137 match self {
138 Self::Gzip(encoder) => match encoder.finish() {
139 Ok(writer) => Ok(writer.buf.freeze()),
140 Err(err) => Err(err),
141 },
142
143 Self::Deflate(encoder) => match encoder.finish() {
144 Ok(writer) => Ok(writer.buf.freeze()),
145 Err(err) => Err(err),
146 },
147
148 Self::Zstd(encoder) => match encoder.finish() {
149 Ok(writer) => Ok(writer.buf.freeze()),
150 Err(err) => Err(err),
151 },
152
153 Self::Br(mut encoder) => match encoder.flush() {
154 Ok(()) => Ok(encoder.into_inner().buf.freeze()),
155 Err(err) => Err(err),
156 },
157 }
158 }
159}
160
161pin_project! {
162 struct EncodedBody<B: Body> {
164 #[pin]
165 inner: B,
166 encoder: Option<Encoder>,
167 state: Option<bool>,
168 }
169}
170
171impl<B: Body> EncodedBody<B> {
172 fn new(b: B, encoder: Encoder) -> Self {
174 Self { inner: b, encoder: Some(encoder), state: Some(true) }
175 }
176}
177
178impl<B> Body for EncodedBody<B>
179where
180 B: Body + Unpin,
181 B::Data: Buf + Debug,
182 B::Error: ToString,
183{
184 type Data = Bytes;
185 type Error = HttpError;
186
187 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
188 let mut this = self.project();
189
190 if this.state.is_none() {
191 return Poll::Ready(None);
192 }
193
194 loop {
195 return match ready!(this.inner.as_mut().poll_frame(cx)) {
196 Some(Ok(frame)) => {
197 let data = match frame.into_data() {
198 Ok(data) => data,
199 Err(mut frame) => {
200 let debug_info = frame.trailers_mut();
201 error!("want to data from body, but receive trailer header: {:?}", debug_info);
202 return Poll::Ready(Some(
203 Err(SendError::invalid_body(format!("invalid body frame : {:?}", debug_info)).into()),
204 ));
205 }
206 };
207
208 match this.encoder.as_mut().unwrap().write(data.chunk()) {
209 Ok(_) => (),
210 Err(e) => {
211 return Poll::Ready(Some(Err(SendError::from(e).into())));
212 }
213 }
214 let bytes = this.encoder.as_mut().unwrap().take();
216 if bytes.is_empty() {
217 continue;
218 }
219 Poll::Ready(Some(Ok(Frame::data(bytes))))
220 }
221 Some(Err(e)) => Poll::Ready(Some(Err(SendError::invalid_body(e.to_string()).into()))),
222 None => {
223 if this.state.is_some() {
224 this.state.take();
226
227 let bytes = match this.encoder.take().unwrap().finish() {
229 Ok(bytes) => bytes,
230 Err(e) => {
231 return Poll::Ready(Some(Err(SendError::from(e).into())));
232 }
233 };
234 if !bytes.is_empty() { Poll::Ready(Some(Ok(Frame::data(bytes)))) } else { Poll::Ready(None) }
235 } else {
236 Poll::Ready(None)
237 }
238 }
239 };
240 }
241 }
242
243 fn is_end_stream(&self) -> bool {
244 self.inner.is_end_stream()
245 }
246}
247
248#[derive(Debug)]
250pub struct EncodeRequestHandler<H: RequestHandler> {
251 handler: H,
252}
253
/// Handler decorator that enables response compression.
///
/// Wrapping a handler with this decorator yields an `EncodeRequestHandler`, which compresses
/// responses according to the client's `Accept-Encoding` header.
#[derive(Debug)]
pub struct EncodeDecorator;
257
258impl<H: RequestHandler> HandlerDecorator<H> for EncodeDecorator {
259 type Output = EncodeRequestHandler<H>;
260
261 fn decorate(&self, raw: H) -> Self::Output {
262 EncodeRequestHandler { handler: raw }
263 }
264}
265
266impl HandlerDecoratorFactory for EncodeDecorator {
267 type Output<In>
268 = EncodeDecorator
269 where
270 In: RequestHandler;
271
272 fn create_decorator<In>(&self) -> Self::Output<In>
273 where
274 In: RequestHandler,
275 {
276 EncodeDecorator
277 }
278}
279
280#[async_trait]
281impl<H: RequestHandler> RequestHandler for EncodeRequestHandler<H> {
282 async fn invoke<'server, 'req>(&self, req: &mut RequestContext<'server, 'req>, req_body: OptionReqBody) -> Response<ResponseBody> {
283 let mut resp = self.handler.invoke(req, req_body).await;
284 encode(req, &mut resp);
285 resp
286 }
287}
288
289fn encode(req: &RequestContext, resp: &mut Response<ResponseBody>) {
291 let status_code = resp.status();
292 if status_code == StatusCode::NO_CONTENT || status_code == StatusCode::SWITCHING_PROTOCOLS {
293 return;
294 }
295
296 if req.headers().contains_key(http::header::CONTENT_ENCODING) {
298 return;
299 }
300
301 let possible_encodings = req.headers().get(http::header::ACCEPT_ENCODING);
303 if possible_encodings.is_none() {
304 return;
305 }
306
307 let accept_encodings = match possible_encodings.unwrap().to_str() {
309 Ok(s) => s,
310 Err(_) => {
311 return;
312 }
313 };
314
315 let encoder = match Encoder::select(accept_encodings) {
316 Some(encoder) => encoder,
317 None => {
318 return;
319 }
320 };
321
322 let body = resp.body_mut();
323
324 if body.is_empty() {
325 return;
326 }
327
328 match body.size_hint().upper() {
329 Some(upper) if upper <= 1024 => {
330 return;
332 }
333 _ => (),
334 }
335
336 let encoder_name = encoder.name();
337 let encoded_body = EncodedBody::new(body.take(), encoder);
338 body.replace(ResponseBody::stream(UnsyncBoxBody::new(encoded_body)));
339
340 resp.headers_mut().remove(http::header::CONTENT_LENGTH);
341 resp.headers_mut().append(http::header::CONTENT_ENCODING, encoder_name.parse().unwrap());
342}