http_response_compression/
body.rs1use crate::codec::Codec;
2use bytes::{Buf, Bytes, BytesMut};
3use compression_codecs::EncodeV2;
4use compression_core::util::{PartialBuffer, WriteBuffer};
5use http_body::{Body, Frame};
6use pin_project_lite::pin_project;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11const OUTPUT_BUFFER_SIZE: usize = 8 * 1024; pin_project! {
14 #[project = CompressionBodyProj]
19 #[allow(missing_docs)]
20 pub enum CompressionBody<B> {
21 Compressed {
23 #[pin]
24 inner: B,
25 state: CompressedBody,
26 },
27 Passthrough {
29 #[pin]
30 inner: B,
31 },
32 }
33}
34
35pub(crate) struct CompressedBody {
37 encoder: Box<dyn EncodeV2 + Send>,
38 output_buffer: Vec<u8>,
39 always_flush: bool,
40 state: CompressState,
41 pending_trailers: Option<http::HeaderMap>,
42}
43
/// Position of a `CompressedBody` in its compression state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum CompressState {
    /// Pulling frames from the inner body and feeding data to the encoder.
    Reading,
    /// The inner body ended (or sent trailers); draining the encoder's final output.
    Finishing,
    /// The encoder has finished; the held-back trailers frame is next.
    Trailers,
    /// The stream is complete; no further frames will be produced.
    Done,
}
56
57impl CompressedBody {
58 fn new(codec: Codec, always_flush: bool) -> Self {
60 Self {
61 encoder: codec.encoder(),
62 output_buffer: vec![0u8; OUTPUT_BUFFER_SIZE],
63 always_flush,
64 state: CompressState::Reading,
65 pending_trailers: None,
66 }
67 }
68
69 pub(crate) fn state(&self) -> CompressState {
71 self.state
72 }
73
74 #[allow(dead_code)]
76 pub(crate) fn always_flush(&self) -> bool {
77 self.always_flush
78 }
79
80 fn poll_compressed<B>(
82 &mut self,
83 cx: &mut Context<'_>,
84 mut inner: Pin<&mut B>,
85 ) -> Poll<Option<Result<Frame<Bytes>, io::Error>>>
86 where
87 B: Body,
88 B::Data: Buf,
89 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
90 {
91 loop {
92 match self.state {
93 CompressState::Done => return Poll::Ready(None),
94
95 CompressState::Trailers => {
96 if let Some(trailers) = self.pending_trailers.take() {
98 self.state = CompressState::Done;
99 return Poll::Ready(Some(Ok(Frame::trailers(trailers))));
100 } else {
101 self.state = CompressState::Done;
102 return Poll::Ready(None);
103 }
104 }
105
106 CompressState::Finishing => {
107 let mut output =
109 WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
110
111 match self.encoder.finish(&mut output) {
112 Ok(done) => {
113 let written = output.written_len();
114 if written > 0 {
115 let data = Bytes::copy_from_slice(&self.output_buffer[..written]);
116 if done {
117 self.state = if self.pending_trailers.is_some() {
118 CompressState::Trailers
119 } else {
120 CompressState::Done
121 };
122 }
123 return Poll::Ready(Some(Ok(Frame::data(data))));
124 } else if done {
125 self.state = if self.pending_trailers.is_some() {
126 CompressState::Trailers
127 } else {
128 CompressState::Done
129 };
130 continue;
131 }
132 }
134 Err(e) => {
135 return Poll::Ready(Some(Err(io::Error::other(e))));
136 }
137 }
138 }
139
140 CompressState::Reading => {
141 match inner.as_mut().poll_frame(cx) {
143 Poll::Pending => return Poll::Pending,
144 Poll::Ready(None) => {
145 self.state = CompressState::Finishing;
147 continue;
148 }
149 Poll::Ready(Some(Err(e))) => {
150 return Poll::Ready(Some(Err(io::Error::other(e.into()))));
151 }
152 Poll::Ready(Some(Ok(frame))) => {
153 if let Some(data) = frame.data_ref() {
154 let input_bytes = collect_bytes(data);
156 return self.compress_chunk(&input_bytes);
157 } else if let Ok(trailers) = frame.into_trailers() {
158 self.pending_trailers = Some(trailers);
160 self.state = CompressState::Finishing;
161 continue;
162 }
163 }
164 }
165 }
166 }
167 }
168 }
169
170 fn compress_chunk(&mut self, input: &[u8]) -> Poll<Option<Result<Frame<Bytes>, io::Error>>> {
172 let mut input_buf = PartialBuffer::new(input);
173 let mut all_output = BytesMut::new();
174
175 loop {
177 let mut output = WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
178
179 if let Err(e) = self.encoder.encode(&mut input_buf, &mut output) {
180 return Poll::Ready(Some(Err(io::Error::other(e))));
181 }
182
183 let written = output.written_len();
184 if written > 0 {
185 all_output.extend_from_slice(&self.output_buffer[..written]);
186 }
187
188 if input_buf.written_len() >= input.len() {
190 break;
191 }
192
193 if written == 0 && input_buf.written_len() == 0 {
195 break;
196 }
197 }
198
199 if self.always_flush {
201 loop {
202 let mut output = WriteBuffer::new_initialized(self.output_buffer.as_mut_slice());
203
204 match self.encoder.flush(&mut output) {
205 Ok(done) => {
206 let written = output.written_len();
207 if written > 0 {
208 all_output.extend_from_slice(&self.output_buffer[..written]);
209 }
210 if done {
211 break;
212 }
213 }
214 Err(e) => {
215 return Poll::Ready(Some(Err(io::Error::other(e))));
216 }
217 }
218 }
219 }
220
221 if all_output.is_empty() {
222 Poll::Pending
224 } else {
225 Poll::Ready(Some(Ok(Frame::data(all_output.freeze()))))
226 }
227 }
228}
229
230impl<B> CompressionBody<B> {
231 pub fn compressed(inner: B, codec: Codec, always_flush: bool) -> Self {
233 Self::Compressed {
234 inner,
235 state: CompressedBody::new(codec, always_flush),
236 }
237 }
238
239 pub fn passthrough(inner: B) -> Self {
241 Self::Passthrough { inner }
242 }
243}
244
245impl<B> Body for CompressionBody<B>
246where
247 B: Body,
248 B::Data: Buf,
249 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
250{
251 type Data = Bytes;
252 type Error = io::Error;
253
254 fn poll_frame(
255 self: Pin<&mut Self>,
256 cx: &mut Context<'_>,
257 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
258 match self.project() {
259 CompressionBodyProj::Passthrough { inner } => {
260 match inner.poll_frame(cx) {
262 Poll::Pending => Poll::Pending,
263 Poll::Ready(None) => Poll::Ready(None),
264 Poll::Ready(Some(Ok(frame))) => {
265 let frame = frame.map_data(|data| {
266 let mut bytes = BytesMut::with_capacity(data.remaining());
267 let mut chunk = data;
268 while chunk.has_remaining() {
269 let slice = chunk.chunk();
270 bytes.extend_from_slice(slice);
271 chunk.advance(slice.len());
272 }
273 bytes.freeze()
274 });
275 Poll::Ready(Some(Ok(frame)))
276 }
277 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(io::Error::other(e.into())))),
278 }
279 }
280 CompressionBodyProj::Compressed { inner, state } => state.poll_compressed(cx, inner),
281 }
282 }
283
284 fn is_end_stream(&self) -> bool {
285 match self {
286 CompressionBody::Passthrough { inner } => inner.is_end_stream(),
287 CompressionBody::Compressed { state, .. } => state.state() == CompressState::Done,
288 }
289 }
290
291 fn size_hint(&self) -> http_body::SizeHint {
292 match self {
293 CompressionBody::Passthrough { inner } => inner.size_hint(),
294 CompressionBody::Compressed { .. } => http_body::SizeHint::default(),
296 }
297 }
298}
299
300fn collect_bytes<D: Buf>(data: &D) -> Vec<u8> {
301 let mut bytes = Vec::with_capacity(data.remaining());
302 let chunk = data.chunk();
303 let remaining = data.remaining();
304 let len = chunk.len().min(remaining);
305 bytes.extend_from_slice(&chunk[..len]);
306 bytes
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderMap;
    use std::collections::VecDeque;

    /// Body that replays a fixed list of frames and never returns `Pending`.
    struct TestBody {
        frames: VecDeque<Frame<Bytes>>,
    }

    impl TestBody {
        fn new(frames: Vec<Frame<Bytes>>) -> Self {
            let frames = VecDeque::from(frames);
            Self { frames }
        }
    }

    impl Body for TestBody {
        type Data = Bytes;
        type Error = std::convert::Infallible;

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            let next = self.frames.pop_front();
            Poll::Ready(next.map(Ok))
        }
    }

    /// Trailer map shared by the trailer tests.
    fn checksum_trailers() -> HeaderMap {
        let mut trailers = HeaderMap::new();
        trailers.insert("x-checksum", "abc123".parse().unwrap());
        trailers
    }

    /// Polls `body` once with a no-op waker, reporting `Pending` as `None`.
    fn poll_body<B: Body + Unpin>(body: &mut B) -> Option<Result<Frame<B::Data>, B::Error>> {
        let mut cx = Context::from_waker(std::task::Waker::noop());
        if let Poll::Ready(result) = Pin::new(body).poll_frame(&mut cx) {
            result
        } else {
            None
        }
    }

    #[test]
    fn test_passthrough_data() {
        let inner = TestBody::new(vec![Frame::data(Bytes::from("hello world"))]);
        let mut body = CompressionBody::passthrough(inner);

        let first = poll_body(&mut body).unwrap().unwrap();
        assert!(first.is_data());
        assert_eq!(first.into_data().unwrap(), Bytes::from("hello world"));

        assert!(poll_body(&mut body).is_none());
    }

    #[test]
    fn test_passthrough_trailers() {
        let inner = TestBody::new(vec![
            Frame::data(Bytes::from("data")),
            Frame::trailers(checksum_trailers()),
        ]);
        let mut body = CompressionBody::passthrough(inner);

        let data_frame = poll_body(&mut body).unwrap().unwrap();
        assert!(data_frame.is_data());

        let trailer_frame = poll_body(&mut body).unwrap().unwrap();
        assert!(trailer_frame.is_trailers());
        let received = trailer_frame.into_trailers().unwrap();
        assert_eq!(received.get("x-checksum").unwrap(), "abc123");

        assert!(poll_body(&mut body).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compressed_produces_output() {
        let inner = TestBody::new(vec![Frame::data(Bytes::from("hello world"))]);
        let mut body = CompressionBody::compressed(inner, Codec::Gzip, false);

        let first = poll_body(&mut body).unwrap().unwrap();
        assert!(first.is_data());
        assert!(!first.into_data().unwrap().is_empty());

        while let Some(Ok(frame)) = poll_body(&mut body) {
            assert!(frame.is_data());
        }
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn test_compressed_with_trailers() {
        let inner = TestBody::new(vec![
            Frame::data(Bytes::from("hello world")),
            Frame::trailers(checksum_trailers()),
        ]);
        let mut body = CompressionBody::compressed(inner, Codec::Gzip, false);

        let mut data_frames = 0;
        let mut received_trailers = None;
        while let Some(Ok(frame)) = poll_body(&mut body) {
            match frame.into_data() {
                Ok(_) => data_frames += 1,
                Err(other) => {
                    if let Ok(trailers) = other.into_trailers() {
                        received_trailers = Some(trailers);
                    }
                }
            }
        }

        assert!(data_frames >= 1);
        let trailers = received_trailers.expect("Expected trailers frame");
        assert_eq!(trailers.get("x-checksum").unwrap(), "abc123");
    }
}