http_response_compression/
body.rs

1use crate::codec::Codec;
2use bytes::{Buf, Bytes, BytesMut};
3use compression_codecs::EncodeV2;
4use compression_core::util::{PartialBuffer, WriteBuffer};
5use http_body::{Body, Frame};
6use pin_project_lite::pin_project;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
/// Size in bytes of the scratch buffer the encoder writes into.
///
/// `CompressedBody::new` allocates one zero-filled buffer of this length and
/// every `encode` / `flush` / `finish` call reuses it.
const OUTPUT_BUFFER_SIZE: usize = 8 * 1024; // 8KB output buffer
12
pin_project! {
    /// A response body that may be compressed.
    ///
    /// This type wraps an inner body and either compresses it using the
    /// specified codec or passes it through unchanged.
    ///
    /// Both variants implement `http_body::Body` with `Bytes` data and
    /// `io::Error` errors, so callers see a single body type regardless of
    /// whether compression was applied. Use `CompressionBody::compressed` or
    /// `CompressionBody::passthrough` to construct a value.
    #[project = CompressionBodyProj]
    #[allow(missing_docs)]
    pub enum CompressionBody<B> {
        /// Compressed body with encoder.
        Compressed {
            // The wrapped body whose data frames are fed to the encoder.
            #[pin]
            inner: B,
            // Encoder, scratch buffer and state machine driving compression.
            state: CompressedBody,
        },
        /// Passthrough body without compression.
        Passthrough {
            // The wrapped body whose frames are forwarded (data converted to `Bytes`).
            #[pin]
            inner: B,
        },
    }
}
34
/// State and buffers for an actively compressed body.
pub(crate) struct CompressedBody {
    /// Encoder obtained from the codec; receives every data chunk of the inner body.
    encoder: Box<dyn EncodeV2 + Send>,
    /// Scratch buffer the encoder writes into; reused for every encode/flush/finish call.
    output_buffer: Vec<u8>,
    /// When `true`, the encoder is flushed after each compressed chunk.
    always_flush: bool,
    /// Current position in the compression state machine.
    state: CompressState,
    /// Trailers received from the inner body, held back until the encoder has finished.
    pending_trailers: Option<http::HeaderMap>,
}
43
/// State machine for compression.
///
/// The body starts in `Reading`, moves to `Finishing` once the inner body ends
/// or yields trailers, then to `Trailers` (only when trailers were buffered),
/// and finally to `Done`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CompressState {
    /// Reading data from inner body and compressing.
    Reading,
    /// Finishing compression after inner body is done.
    Finishing,
    /// Emitting buffered trailers.
    Trailers,
    /// Compression is complete.
    Done,
}
56
57impl CompressedBody {
58    /// Creates a new compressed body state with the given codec.
59    fn new(codec: Codec, always_flush: bool) -> Self {
60        Self {
61            encoder: codec.encoder(),
62            output_buffer: vec![0u8; OUTPUT_BUFFER_SIZE],
63            always_flush,
64            state: CompressState::Reading,
65            pending_trailers: None,
66        }
67    }
68
69    /// Returns the current compression state.
70    pub(crate) fn state(&self) -> CompressState {
71        self.state
72    }
73
74    /// Returns whether always flush is enabled.
75    #[allow(dead_code)]
76    pub(crate) fn always_flush(&self) -> bool {
77        self.always_flush
78    }
79
80    /// Polls the inner body and compresses data.
81    fn poll_compressed<B>(
82        &mut self,
83        cx: &mut Context<'_>,
84        mut inner: Pin<&mut B>,
85    ) -> Poll<Option<Result<Frame<Bytes>, io::Error>>>
86    where
87        B: Body,
88        B::Data: Buf,
89        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
90    {
91        loop {
92            match self.state {
93                CompressState::Done => return Poll::Ready(None),
94
95                CompressState::Trailers => {
96                    // Emit buffered trailers
97                    if let Some(trailers) = self.pending_trailers.take() {
98                        self.state = CompressState::Done;
99                        return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
100                    } else {
101                        self.state = CompressState::Done;
102                        return Poll::Ready(None);
103                    }
104                }
105
106                CompressState::Finishing => {
107                    // Finish the encoder
108                    let mut output =
109                        WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
110
111                    match self.encoder.finish(&mut output) {
112                        Ok(done) => {
113                            let written = output.written_len();
114                            if written > 0 {
115                                let data = Bytes::copy_from_slice(&self.output_buffer[..written]);
116                                if done {
117                                    self.state = if self.pending_trailers.is_some() {
118                                        CompressState::Trailers
119                                    } else {
120                                        CompressState::Done
121                                    };
122                                }
123                                return Poll::Ready(Some(Ok(Frame::data(data))));
124                            } else if done {
125                                self.state = if self.pending_trailers.is_some() {
126                                    CompressState::Trailers
127                                } else {
128                                    CompressState::Done
129                                };
130                                continue;
131                            }
132                            // Continue looping to finish
133                        }
134                        Err(e) => {
135                            return Poll::Ready(Some(Err(io::Error::other(e))));
136                        }
137                    }
138                }
139
140                CompressState::Reading => {
141                    // Poll inner body for data
142                    match inner.as_mut().poll_frame(cx) {
143                        Poll::Pending => return Poll::Pending,
144                        Poll::Ready(None) => {
145                            // Inner body is done, transition to finishing
146                            self.state = CompressState::Finishing;
147                            continue;
148                        }
149                        Poll::Ready(Some(Err(e))) => {
150                            return Poll::Ready(Some(Err(io::Error::other(e.into()))));
151                        }
152                        Poll::Ready(Some(Ok(frame))) => {
153                            if let Some(data) = frame.data_ref() {
154                                // Compress the data
155                                let input_bytes = collect_bytes(data);
156                                return self.compress_chunk(&input_bytes);
157                            } else if let Ok(trailers) = frame.into_trailers() {
158                                // Buffer trailers and finish compression first
159                                self.pending_trailers = Some(trailers);
160                                self.state = CompressState::Finishing;
161                                continue;
162                            }
163                        }
164                    }
165                }
166            }
167        }
168    }
169
170    /// Compresses a chunk of input data.
171    fn compress_chunk(&mut self, input: &[u8]) -> Poll<Option<Result<Frame<Bytes>, io::Error>>> {
172        let mut input_buf = PartialBuffer::new(input);
173        let mut all_output = BytesMut::new();
174
175        // Keep encoding until all input is consumed
176        loop {
177            let mut output = WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
178
179            if let Err(e) = self.encoder.encode(&mut input_buf, &mut output) {
180                return Poll::Ready(Some(Err(io::Error::other(e))));
181            }
182
183            let written = output.written_len();
184            if written > 0 {
185                all_output.extend_from_slice(&self.output_buffer[..written]);
186            }
187
188            // Check if we've consumed all input
189            if input_buf.written_len() >= input.len() {
190                break;
191            }
192
193            // Safety check to prevent infinite loop
194            if written == 0 && input_buf.written_len() == 0 {
195                break;
196            }
197        }
198
199        // Flush if always_flush is enabled
200        if self.always_flush {
201            loop {
202                let mut output = WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
203
204                match self.encoder.flush(&mut output) {
205                    Ok(done) => {
206                        let written = output.written_len();
207                        if written > 0 {
208                            all_output.extend_from_slice(&self.output_buffer[..written]);
209                        }
210                        if done {
211                            break;
212                        }
213                    }
214                    Err(e) => {
215                        return Poll::Ready(Some(Err(io::Error::other(e))));
216                    }
217                }
218            }
219        }
220
221        if all_output.is_empty() {
222            // No output yet, need to continue polling
223            Poll::Pending
224        } else {
225            Poll::Ready(Some(Ok(Frame::data(all_output.freeze()))))
226        }
227    }
228}
229
230impl<B> CompressionBody<B> {
231    /// Creates a compressed body with the given codec.
232    pub fn compressed(inner: B, codec: Codec, always_flush: bool) -> Self {
233        Self::Compressed {
234            inner,
235            state: CompressedBody::new(codec, always_flush),
236        }
237    }
238
239    /// Creates a passthrough body without compression.
240    pub fn passthrough(inner: B) -> Self {
241        Self::Passthrough { inner }
242    }
243}
244
245impl<B> Body for CompressionBody<B>
246where
247    B: Body,
248    B::Data: Buf,
249    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
250{
251    type Data = Bytes;
252    type Error = io::Error;
253
254    fn poll_frame(
255        self: Pin<&mut Self>,
256        cx: &mut Context<'_>,
257    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
258        match self.project() {
259            CompressionBodyProj::Passthrough { inner } => {
260                // Pass through frames, converting data to Bytes
261                match inner.poll_frame(cx) {
262                    Poll::Pending => Poll::Pending,
263                    Poll::Ready(None) => Poll::Ready(None),
264                    Poll::Ready(Some(Ok(frame))) => {
265                        let frame = frame.map_data(|data| {
266                            let mut bytes = BytesMut::with_capacity(data.remaining());
267                            let mut chunk = data;
268                            while chunk.has_remaining() {
269                                let slice = chunk.chunk();
270                                bytes.extend_from_slice(slice);
271                                chunk.advance(slice.len());
272                            }
273                            bytes.freeze()
274                        });
275                        Poll::Ready(Some(Ok(frame)))
276                    }
277                    Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(io::Error::other(e.into())))),
278                }
279            }
280            CompressionBodyProj::Compressed { inner, state } => state.poll_compressed(cx, inner),
281        }
282    }
283
284    fn is_end_stream(&self) -> bool {
285        match self {
286            CompressionBody::Passthrough { inner } => inner.is_end_stream(),
287            CompressionBody::Compressed { state, .. } => state.state() == CompressState::Done,
288        }
289    }
290
291    fn size_hint(&self) -> http_body::SizeHint {
292        match self {
293            CompressionBody::Passthrough { inner } => inner.size_hint(),
294            // Compressed size is unknown
295            CompressionBody::Compressed { .. } => http_body::SizeHint::default(),
296        }
297    }
298}
299
300fn collect_bytes<D: Buf>(data: &D) -> Vec<u8> {
301    let mut bytes = Vec::with_capacity(data.remaining());
302    let chunk = data.chunk();
303    let remaining = data.remaining();
304    let len = chunk.len().min(remaining);
305    bytes.extend_from_slice(&chunk[..len]);
306    bytes
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderMap;
    use std::collections::VecDeque;

    /// A test body that yields predefined frames.
    ///
    /// It never returns `Poll::Pending` and never fails.
    struct TestBody {
        frames: VecDeque<Frame<Bytes>>,
    }

    impl TestBody {
        fn new(frames: Vec<Frame<Bytes>>) -> Self {
            Self {
                frames: frames.into(),
            }
        }
    }

    impl Body for TestBody {
        type Data = Bytes;
        type Error = std::convert::Infallible;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            match self.frames.pop_front() {
                Some(frame) => Poll::Ready(Some(Ok(frame))),
                None => Poll::Ready(None),
            }
        }
    }

    /// Polls `body` once with a no-op waker and returns the ready result.
    ///
    /// Panics on `Poll::Pending`: `TestBody` is always ready and the no-op
    /// waker can never resume the body, so a pending result would be a hang in
    /// production rather than a clean end of stream.
    fn poll_body<B: Body + Unpin>(body: &mut B) -> Option<Result<Frame<B::Data>, B::Error>> {
        let waker = std::task::Waker::noop();
        let mut cx = Context::from_waker(waker);
        match Pin::new(body).poll_frame(&mut cx) {
            Poll::Ready(result) => result,
            Poll::Pending => panic!("body returned Poll::Pending although the inner body is always ready"),
        }
    }

    #[test]
    fn test_passthrough_data() {
        let inner = TestBody::new(vec![Frame::data(Bytes::from("hello world"))]);
        let mut body = CompressionBody::passthrough(inner);

        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_data());
        assert_eq!(frame.into_data().unwrap(), Bytes::from("hello world"));

        assert!(poll_body(&mut body).is_none());
    }

    #[test]
    fn test_passthrough_trailers() {
        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", "abc123".parse().unwrap());

        let inner = TestBody::new(vec![
            Frame::data(Bytes::from("data")),
            Frame::trailers(trailers.clone()),
        ]);
        let mut body = CompressionBody::passthrough(inner);

        // First frame is data
        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_data());

        // Second frame is trailers
        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_trailers());
        let received_trailers = frame.into_trailers().unwrap();
        assert_eq!(received_trailers.get("x-checksum").unwrap(), "abc123");

        assert!(poll_body(&mut body).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compressed_produces_output() {
        let inner = TestBody::new(vec![Frame::data(Bytes::from("hello world"))]);
        let mut body = CompressionBody::compressed(inner, Codec::Gzip, false);

        // Should get compressed data
        let frame = poll_body(&mut body).unwrap().unwrap();
        assert!(frame.is_data());
        let data = frame.into_data().unwrap();
        // The first output of a gzip stream starts with the magic bytes 0x1f 0x8b.
        assert!(!data.is_empty());
        assert!(data.starts_with(&[0x1f, 0x8b]));

        // Should get more data from finishing; an error must fail the test
        // instead of silently ending the loop.
        while let Some(result) = poll_body(&mut body) {
            let frame = result.expect("compressed body yielded an error");
            assert!(frame.is_data());
        }
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compressed_with_trailers() {
        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", "abc123".parse().unwrap());

        let inner = TestBody::new(vec![
            Frame::data(Bytes::from("hello world")),
            Frame::trailers(trailers),
        ]);
        let mut body = CompressionBody::compressed(inner, Codec::Gzip, false);

        // Collect all frames
        let mut data_frames = 0;
        let mut trailer_frame = None;
        while let Some(result) = poll_body(&mut body) {
            let frame = result.expect("compressed body yielded an error");
            if frame.is_data() {
                data_frames += 1;
            } else if frame.is_trailers() {
                trailer_frame = Some(frame);
            }
        }

        // Should have received at least one data frame
        assert!(data_frames >= 1);

        // Should have received trailers
        let trailers = trailer_frame
            .expect("Expected trailers frame")
            .into_trailers()
            .unwrap();
        assert_eq!(trailers.get("x-checksum").unwrap(), "abc123");
    }
}