http_response_compression/
body.rs

1use crate::codec::Codec;
2use bytes::{Buf, Bytes, BytesMut};
3use compression_codecs::EncodeV2;
4use compression_core::util::{PartialBuffer, WriteBuffer};
5use http_body::{Body, Frame};
6use pin_project_lite::pin_project;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
/// Size, in bytes, of the scratch buffer the encoder writes into (8 KiB).
const OUTPUT_BUFFER_SIZE: usize = 8 * 1024;
12
13pin_project! {
14    /// A response body that may be compressed.
15    ///
16    /// This type wraps an inner body and either compresses it using the
17    /// specified codec or passes it through unchanged.
18    #[project = CompressionBodyProj]
19    #[allow(missing_docs)]
20    pub enum CompressionBody<B> {
21        /// Compressed body with encoder.
22        Compressed {
23            #[pin]
24            inner: B,
25            state: CompressedBody,
26        },
27        /// Passthrough body without compression.
28        Passthrough {
29            #[pin]
30            inner: B,
31        },
32    }
33}
34
35/// State and buffers for an actively compressed body.
36pub(crate) struct CompressedBody {
37    encoder: Box<dyn EncodeV2 + Send>,
38    output_buffer: Vec<u8>,
39    always_flush: bool,
40    state: CompressState,
41    pending_trailers: Option<http::HeaderMap>,
42}
43
/// Phases a compressed body moves through, in order.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum CompressState {
    /// Frames are still being pulled from the inner body and encoded.
    Reading,
    /// The inner body has ended; the encoder is being finished.
    Finishing,
    /// The encoder is finished and buffered trailers are about to be emitted.
    Trailers,
    /// Nothing is left to emit.
    Done,
}
56
57impl CompressedBody {
58    /// Creates a new compressed body state with the given codec.
59    fn new(codec: Codec, always_flush: bool) -> Self {
60        Self {
61            encoder: codec.encoder(),
62            output_buffer: vec![0u8; OUTPUT_BUFFER_SIZE],
63            always_flush,
64            state: CompressState::Reading,
65            pending_trailers: None,
66        }
67    }
68
69    /// Returns the current compression state.
70    pub(crate) fn state(&self) -> CompressState {
71        self.state
72    }
73
74    /// Returns whether always flush is enabled.
75    #[allow(dead_code)]
76    pub(crate) fn always_flush(&self) -> bool {
77        self.always_flush
78    }
79
80    /// Polls the inner body and compresses data.
81    fn poll_compressed<B>(
82        &mut self,
83        cx: &mut Context<'_>,
84        mut inner: Pin<&mut B>,
85    ) -> Poll<Option<Result<Frame<Bytes>, io::Error>>>
86    where
87        B: Body,
88        B::Data: Buf,
89        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
90    {
91        loop {
92            match self.state {
93                CompressState::Done => return Poll::Ready(None),
94
95                CompressState::Trailers => {
96                    // Emit buffered trailers
97                    if let Some(trailers) = self.pending_trailers.take() {
98                        self.state = CompressState::Done;
99                        return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
100                    } else {
101                        self.state = CompressState::Done;
102                        return Poll::Ready(None);
103                    }
104                }
105
106                CompressState::Finishing => {
107                    // Finish the encoder
108                    let mut output =
109                        WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
110
111                    match self.encoder.finish(&mut output) {
112                        Ok(done) => {
113                            let written = output.written_len();
114                            if written > 0 {
115                                let data = Bytes::copy_from_slice(&self.output_buffer[..written]);
116                                if done {
117                                    self.state = if self.pending_trailers.is_some() {
118                                        CompressState::Trailers
119                                    } else {
120                                        CompressState::Done
121                                    };
122                                }
123                                return Poll::Ready(Some(Ok(Frame::data(data))));
124                            } else if done {
125                                self.state = if self.pending_trailers.is_some() {
126                                    CompressState::Trailers
127                                } else {
128                                    CompressState::Done
129                                };
130                                continue;
131                            }
132                            // Continue looping to finish
133                        }
134                        Err(e) => {
135                            return Poll::Ready(Some(Err(io::Error::other(e))));
136                        }
137                    }
138                }
139
140                CompressState::Reading => {
141                    // Poll inner body for data
142                    match inner.as_mut().poll_frame(cx) {
143                        Poll::Pending => return Poll::Pending,
144                        Poll::Ready(None) => {
145                            // Inner body is done, transition to finishing
146                            self.state = CompressState::Finishing;
147                            continue;
148                        }
149                        Poll::Ready(Some(Err(e))) => {
150                            return Poll::Ready(Some(Err(io::Error::other(e.into()))));
151                        }
152                        Poll::Ready(Some(Ok(frame))) => {
153                            match frame.into_data() {
154                                Ok(mut data) => {
155                                    // Compress the data
156                                    let input_bytes = data.copy_to_bytes(data.remaining());
157                                    return self.compress_chunk(&input_bytes);
158                                }
159                                Err(frame) => {
160                                    if let Ok(trailers) = frame.into_trailers() {
161                                        // Buffer trailers and finish compression first
162                                        self.pending_trailers = Some(trailers);
163                                        self.state = CompressState::Finishing;
164                                        continue;
165                                    }
166                                }
167                            }
168                        }
169                    }
170                }
171            }
172        }
173    }
174
175    /// Compresses a chunk of input data.
176    fn compress_chunk(&mut self, input: &[u8]) -> Poll<Option<Result<Frame<Bytes>, io::Error>>> {
177        let mut input_buf = PartialBuffer::new(input);
178        let mut all_output = BytesMut::new();
179
180        // Keep encoding until all input is consumed
181        loop {
182            let mut output = WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
183
184            if let Err(e) = self.encoder.encode(&mut input_buf, &mut output) {
185                return Poll::Ready(Some(Err(io::Error::other(e))));
186            }
187
188            let written = output.written_len();
189            if written > 0 {
190                all_output.extend_from_slice(&self.output_buffer[..written]);
191            }
192
193            // Check if we've consumed all input
194            if input_buf.written_len() >= input.len() {
195                break;
196            }
197
198            // Safety check to prevent infinite loop
199            if written == 0 && input_buf.written_len() == 0 {
200                break;
201            }
202        }
203
204        // Flush if always_flush is enabled
205        if self.always_flush {
206            loop {
207                let mut output = WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
208
209                match self.encoder.flush(&mut output) {
210                    Ok(done) => {
211                        let written = output.written_len();
212                        if written > 0 {
213                            all_output.extend_from_slice(&self.output_buffer[..written]);
214                        }
215                        if done {
216                            break;
217                        }
218                    }
219                    Err(e) => {
220                        return Poll::Ready(Some(Err(io::Error::other(e))));
221                    }
222                }
223            }
224        }
225
226        if all_output.is_empty() {
227            // No output yet, need to continue polling
228            Poll::Pending
229        } else {
230            Poll::Ready(Some(Ok(Frame::data(all_output.freeze()))))
231        }
232    }
233}
234
235impl<B> CompressionBody<B> {
236    /// Creates a compressed body with the given codec.
237    pub fn compressed(inner: B, codec: Codec, always_flush: bool) -> Self {
238        Self::Compressed {
239            inner,
240            state: CompressedBody::new(codec, always_flush),
241        }
242    }
243
244    /// Creates a passthrough body without compression.
245    pub fn passthrough(inner: B) -> Self {
246        Self::Passthrough { inner }
247    }
248}
249
250impl<B> Body for CompressionBody<B>
251where
252    B: Body,
253    B::Data: Buf,
254    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
255{
256    type Data = Bytes;
257    type Error = io::Error;
258
259    fn poll_frame(
260        self: Pin<&mut Self>,
261        cx: &mut Context<'_>,
262    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
263        match self.project() {
264            CompressionBodyProj::Passthrough { inner } => {
265                // Pass through frames, converting data to Bytes
266                match inner.poll_frame(cx) {
267                    Poll::Pending => Poll::Pending,
268                    Poll::Ready(None) => Poll::Ready(None),
269                    Poll::Ready(Some(Ok(frame))) => {
270                        let frame = frame.map_data(|mut data| data.copy_to_bytes(data.remaining()));
271                        Poll::Ready(Some(Ok(frame)))
272                    }
273                    Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(io::Error::other(e.into())))),
274                }
275            }
276            CompressionBodyProj::Compressed { inner, state } => state.poll_compressed(cx, inner),
277        }
278    }
279
280    fn is_end_stream(&self) -> bool {
281        match self {
282            CompressionBody::Passthrough { inner } => inner.is_end_stream(),
283            CompressionBody::Compressed { state, .. } => state.state() == CompressState::Done,
284        }
285    }
286
287    fn size_hint(&self) -> http_body::SizeHint {
288        match self {
289            CompressionBody::Passthrough { inner } => inner.size_hint(),
290            // Compressed size is unknown
291            CompressionBody::Compressed { .. } => http_body::SizeHint::default(),
292        }
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderMap;
    use std::collections::VecDeque;

    /// A test body that yields a fixed sequence of frames and never pends.
    struct TestBody {
        frames: VecDeque<Frame<Bytes>>,
    }

    impl TestBody {
        fn new(frames: Vec<Frame<Bytes>>) -> Self {
            Self {
                frames: frames.into(),
            }
        }
    }

    impl Body for TestBody {
        type Data = Bytes;
        type Error = std::convert::Infallible;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(self.frames.pop_front().map(Ok))
        }
    }

    /// Polls `body` once with a no-op waker.
    ///
    /// `TestBody` is always ready, so `Pending` from a wrapper around it
    /// means the wrapper would stall; it is reported as a test failure
    /// instead of being mistaken for end of stream.
    fn poll_body<B: Body + Unpin>(body: &mut B) -> Option<Result<Frame<B::Data>, B::Error>> {
        let mut cx = Context::from_waker(std::task::Waker::noop());
        match Pin::new(body).poll_frame(&mut cx) {
            Poll::Ready(result) => result,
            Poll::Pending => panic!("body returned Pending although its inner body is ready"),
        }
    }

    /// Drains `body`, returning the concatenated data and the trailers.
    ///
    /// Panics if the body yields an error frame or a data frame after the
    /// trailers.
    #[cfg(feature = "gzip")]
    fn collect_body<B>(body: &mut B) -> (Vec<u8>, Option<HeaderMap>)
    where
        B: Body<Data = Bytes> + Unpin,
        B::Error: std::fmt::Debug,
    {
        let mut data = Vec::new();
        let mut trailers = None;
        while let Some(result) = poll_body(body) {
            let frame = result.expect("body yielded an error frame");
            match frame.into_data() {
                Ok(chunk) => {
                    assert!(trailers.is_none(), "data frame arrived after trailers");
                    data.extend_from_slice(&chunk);
                }
                Err(frame) => {
                    if let Ok(map) = frame.into_trailers() {
                        trailers = Some(map);
                    }
                }
            }
        }
        (data, trailers)
    }

    /// Checks the gzip framing of `compressed`: the two magic bytes at the
    /// start and the little-endian uncompressed length (ISIZE) at the end.
    #[cfg(feature = "gzip")]
    fn assert_gzip_framing(compressed: &[u8], uncompressed_len: usize) {
        assert!(
            compressed.starts_with(&[0x1f, 0x8b]),
            "output does not start with the gzip magic bytes"
        );
        let isize_field = u32::try_from(uncompressed_len)
            .expect("test payload fits in u32")
            .to_le_bytes();
        assert!(
            compressed.ends_with(&isize_field),
            "gzip trailer does not record the uncompressed length"
        );
    }

    #[test]
    fn test_passthrough_data() {
        let inner = TestBody::new(vec![Frame::data(Bytes::from("hello world"))]);
        let mut body = CompressionBody::passthrough(inner);

        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_data());
        assert_eq!(frame.into_data().unwrap(), Bytes::from("hello world"));

        assert!(poll_body(&mut body).is_none());
    }

    #[test]
    fn test_passthrough_trailers() {
        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", "abc123".parse().unwrap());

        let inner = TestBody::new(vec![
            Frame::data(Bytes::from("data")),
            Frame::trailers(trailers),
        ]);
        let mut body = CompressionBody::passthrough(inner);

        // First frame is data
        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_data());

        // Second frame is trailers
        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_trailers());
        let received_trailers = frame.into_trailers().unwrap();
        assert_eq!(received_trailers.get("x-checksum").unwrap(), "abc123");

        assert!(poll_body(&mut body).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compressed_produces_output() {
        let payload = "hello world";
        let inner = TestBody::new(vec![Frame::data(Bytes::from(payload))]);
        let mut body = CompressionBody::compressed(inner, Codec::Gzip, false);

        let (compressed, trailers) = collect_body(&mut body);

        assert!(trailers.is_none());
        assert_gzip_framing(&compressed, payload.len());
        assert!(body.is_end_stream());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compressed_with_trailers() {
        let payload = "hello world";
        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", "abc123".parse().unwrap());

        let inner = TestBody::new(vec![
            Frame::data(Bytes::from(payload)),
            Frame::trailers(trailers),
        ]);
        let mut body = CompressionBody::compressed(inner, Codec::Gzip, false);

        let (compressed, received) = collect_body(&mut body);

        // The compressed stream must be complete before the trailers.
        assert_gzip_framing(&compressed, payload.len());

        let received = received.expect("Expected trailers frame");
        assert_eq!(received.get("x-checksum").unwrap(), "abc123");
        assert!(body.is_end_stream());
    }
}
427}