http_response_compression/
body.rs1use crate::codec::Codec;
2use bytes::{Buf, Bytes, BytesMut};
3use compression_codecs::EncodeV2;
4use compression_core::util::{PartialBuffer, WriteBuffer};
5use http_body::{Body, Frame};
6use pin_project_lite::pin_project;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11const OUTPUT_BUFFER_SIZE: usize = 8 * 1024; pin_project! {
14 #[project = CompressionBodyProj]
19 #[allow(missing_docs)]
20 pub enum CompressionBody<B> {
21 Compressed {
23 #[pin]
24 inner: B,
25 state: CompressedBody,
26 },
27 Passthrough {
29 #[pin]
30 inner: B,
31 },
32 }
33}
34
35pub(crate) struct CompressedBody {
37 encoder: Box<dyn EncodeV2 + Send>,
38 output_buffer: Vec<u8>,
39 always_flush: bool,
40 state: CompressState,
41 pending_trailers: Option<http::HeaderMap>,
42}
43
/// Lifecycle of a compressed body, advanced as frames are polled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum CompressState {
    /// Pulling frames from the inner body and compressing data chunks.
    Reading,
    /// Inner body exhausted (or trailers seen); draining the encoder's final output.
    Finishing,
    /// Compressed stream complete; stashed trailers, if any, still to be sent.
    Trailers,
    /// Nothing left to emit.
    Done,
}
56
57impl CompressedBody {
58 fn new(codec: Codec, always_flush: bool) -> Self {
60 Self {
61 encoder: codec.encoder(),
62 output_buffer: vec![0u8; OUTPUT_BUFFER_SIZE],
63 always_flush,
64 state: CompressState::Reading,
65 pending_trailers: None,
66 }
67 }
68
69 pub(crate) fn state(&self) -> CompressState {
71 self.state
72 }
73
74 #[allow(dead_code)]
76 pub(crate) fn always_flush(&self) -> bool {
77 self.always_flush
78 }
79
80 fn poll_compressed<B>(
82 &mut self,
83 cx: &mut Context<'_>,
84 mut inner: Pin<&mut B>,
85 ) -> Poll<Option<Result<Frame<Bytes>, io::Error>>>
86 where
87 B: Body,
88 B::Data: Buf,
89 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
90 {
91 loop {
92 match self.state {
93 CompressState::Done => return Poll::Ready(None),
94
95 CompressState::Trailers => {
96 if let Some(trailers) = self.pending_trailers.take() {
98 self.state = CompressState::Done;
99 return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
100 } else {
101 self.state = CompressState::Done;
102 return Poll::Ready(None);
103 }
104 }
105
106 CompressState::Finishing => {
107 let mut output =
109 WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
110
111 match self.encoder.finish(&mut output) {
112 Ok(done) => {
113 let written = output.written_len();
114 if written > 0 {
115 let data = Bytes::copy_from_slice(&self.output_buffer[..written]);
116 if done {
117 self.state = if self.pending_trailers.is_some() {
118 CompressState::Trailers
119 } else {
120 CompressState::Done
121 };
122 }
123 return Poll::Ready(Some(Ok(Frame::data(data))));
124 } else if done {
125 self.state = if self.pending_trailers.is_some() {
126 CompressState::Trailers
127 } else {
128 CompressState::Done
129 };
130 continue;
131 }
132 }
134 Err(e) => {
135 return Poll::Ready(Some(Err(io::Error::other(e))));
136 }
137 }
138 }
139
140 CompressState::Reading => {
141 match inner.as_mut().poll_frame(cx) {
143 Poll::Pending => return Poll::Pending,
144 Poll::Ready(None) => {
145 self.state = CompressState::Finishing;
147 continue;
148 }
149 Poll::Ready(Some(Err(e))) => {
150 return Poll::Ready(Some(Err(io::Error::other(e.into()))));
151 }
152 Poll::Ready(Some(Ok(frame))) => {
153 match frame.into_data() {
154 Ok(mut data) => {
155 let input_bytes = data.copy_to_bytes(data.remaining());
157 return self.compress_chunk(&input_bytes);
158 }
159 Err(frame) => {
160 if let Ok(trailers) = frame.into_trailers() {
161 self.pending_trailers = Some(trailers);
163 self.state = CompressState::Finishing;
164 continue;
165 }
166 }
167 }
168 }
169 }
170 }
171 }
172 }
173 }
174
175 fn compress_chunk(&mut self, input: &[u8]) -> Poll<Option<Result<Frame<Bytes>, io::Error>>> {
177 let mut input_buf = PartialBuffer::new(input);
178 let mut all_output = BytesMut::new();
179
180 loop {
182 let mut output = WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
183
184 if let Err(e) = self.encoder.encode(&mut input_buf, &mut output) {
185 return Poll::Ready(Some(Err(io::Error::other(e))));
186 }
187
188 let written = output.written_len();
189 if written > 0 {
190 all_output.extend_from_slice(&self.output_buffer[..written]);
191 }
192
193 if input_buf.written_len() >= input.len() {
195 break;
196 }
197
198 if written == 0 && input_buf.written_len() == 0 {
200 break;
201 }
202 }
203
204 if self.always_flush {
206 loop {
207 let mut output = WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
208
209 match self.encoder.flush(&mut output) {
210 Ok(done) => {
211 let written = output.written_len();
212 if written > 0 {
213 all_output.extend_from_slice(&self.output_buffer[..written]);
214 }
215 if done {
216 break;
217 }
218 }
219 Err(e) => {
220 return Poll::Ready(Some(Err(io::Error::other(e))));
221 }
222 }
223 }
224 }
225
226 if all_output.is_empty() {
227 Poll::Pending
229 } else {
230 Poll::Ready(Some(Ok(Frame::data(all_output.freeze()))))
231 }
232 }
233}
234
235impl<B> CompressionBody<B> {
236 pub fn compressed(inner: B, codec: Codec, always_flush: bool) -> Self {
238 Self::Compressed {
239 inner,
240 state: CompressedBody::new(codec, always_flush),
241 }
242 }
243
244 pub fn passthrough(inner: B) -> Self {
246 Self::Passthrough { inner }
247 }
248}
249
250impl<B> Body for CompressionBody<B>
251where
252 B: Body,
253 B::Data: Buf,
254 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
255{
256 type Data = Bytes;
257 type Error = io::Error;
258
259 fn poll_frame(
260 self: Pin<&mut Self>,
261 cx: &mut Context<'_>,
262 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
263 match self.project() {
264 CompressionBodyProj::Passthrough { inner } => {
265 match inner.poll_frame(cx) {
267 Poll::Pending => Poll::Pending,
268 Poll::Ready(None) => Poll::Ready(None),
269 Poll::Ready(Some(Ok(frame))) => {
270 let frame = frame.map_data(|mut data| data.copy_to_bytes(data.remaining()));
271 Poll::Ready(Some(Ok(frame)))
272 }
273 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(io::Error::other(e.into())))),
274 }
275 }
276 CompressionBodyProj::Compressed { inner, state } => state.poll_compressed(cx, inner),
277 }
278 }
279
280 fn is_end_stream(&self) -> bool {
281 match self {
282 CompressionBody::Passthrough { inner } => inner.is_end_stream(),
283 CompressionBody::Compressed { state, .. } => state.state() == CompressState::Done,
284 }
285 }
286
287 fn size_hint(&self) -> http_body::SizeHint {
288 match self {
289 CompressionBody::Passthrough { inner } => inner.size_hint(),
290 CompressionBody::Compressed { .. } => http_body::SizeHint::default(),
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderMap;
    use std::collections::VecDeque;

    /// A body that yields a fixed list of frames and is never `Pending`.
    struct TestBody {
        frames: VecDeque<Frame<Bytes>>,
    }

    impl TestBody {
        fn new(frames: Vec<Frame<Bytes>>) -> Self {
            Self {
                frames: frames.into(),
            }
        }
    }

    impl Body for TestBody {
        type Data = Bytes;
        type Error = std::convert::Infallible;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(self.frames.pop_front().map(Ok))
        }
    }

    /// Polls `body` once with a no-op waker.
    ///
    /// # Panics
    /// Panics if `body` returns `Pending`. Every body polled here wraps a
    /// `TestBody`, which is always ready, so `Pending` would mean the wrapper
    /// stalled without a registered waker. Treating `Pending` as end-of-stream
    /// would hide such hangs.
    fn poll_body<B: Body + Unpin>(body: &mut B) -> Option<Result<Frame<B::Data>, B::Error>> {
        let mut cx = Context::from_waker(std::task::Waker::noop());
        match Pin::new(body).poll_frame(&mut cx) {
            Poll::Ready(result) => result,
            Poll::Pending => panic!("body returned Pending although its inner body is always ready"),
        }
    }

    #[test]
    fn test_passthrough_data() {
        let inner = TestBody::new(vec![Frame::data(Bytes::from("hello world"))]);
        let mut body = CompressionBody::passthrough(inner);

        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_data());
        assert_eq!(frame.into_data().unwrap(), Bytes::from("hello world"));

        assert!(poll_body(&mut body).is_none());
    }

    #[test]
    fn test_passthrough_trailers() {
        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", "abc123".parse().unwrap());

        let inner = TestBody::new(vec![
            Frame::data(Bytes::from("data")),
            Frame::trailers(trailers.clone()),
        ]);
        let mut body = CompressionBody::passthrough(inner);

        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_data());

        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_trailers());
        let received_trailers = frame.into_trailers().unwrap();
        assert_eq!(received_trailers.get("x-checksum").unwrap(), "abc123");

        assert!(poll_body(&mut body).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compressed_produces_output() {
        let inner = TestBody::new(vec![Frame::data(Bytes::from("hello world"))]);
        let mut body = CompressionBody::compressed(inner, Codec::Gzip, false);

        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_data());
        let data = frame.into_data().unwrap();
        assert!(!data.is_empty());

        // Drain the rest of the stream; errors must surface, not end the loop.
        while let Some(result) = poll_body(&mut body) {
            let frame = result.expect("compression should not fail");
            assert!(frame.is_data());
        }
        assert!(body.is_end_stream());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compressed_with_trailers() {
        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", "abc123".parse().unwrap());

        let inner = TestBody::new(vec![
            Frame::data(Bytes::from("hello world")),
            Frame::trailers(trailers),
        ]);
        let mut body = CompressionBody::compressed(inner, Codec::Gzip, false);

        let mut data_frames = 0;
        let mut trailer_frame = None;
        while let Some(result) = poll_body(&mut body) {
            let frame = result.expect("compression should not fail");
            if frame.is_data() {
                data_frames += 1;
            } else if frame.is_trailers() {
                trailer_frame = Some(frame);
            }
        }

        assert!(data_frames >= 1);
        assert!(body.is_end_stream());

        let trailers = trailer_frame
            .expect("Expected trailers frame")
            .into_trailers()
            .unwrap();
        assert_eq!(trailers.get("x-checksum").unwrap(), "abc123");
    }
}