rama_http/layer/decompression/body.rs

1#![allow(unused_imports)]
2
3use crate::HeaderMap;
4use crate::dep::http_body::{Body, Frame, SizeHint};
5use crate::layer::util::compression::{
6    AsyncReadBody, BodyIntoStream, CompressionLevel, DecorateAsyncRead, WrapBody,
7};
8use rama_core::error::BoxError;
9
10use async_compression::tokio::bufread::BrotliDecoder;
11use async_compression::tokio::bufread::GzipDecoder;
12use async_compression::tokio::bufread::ZlibDecoder;
13use async_compression::tokio::bufread::ZstdDecoder;
14use bytes::{Buf, Bytes};
15use futures_lite::ready;
16use pin_project_lite::pin_project;
17use std::task::Context;
18use std::{io, marker::PhantomData, pin::Pin, task::Poll};
19use tokio_util::io::StreamReader;
20
pin_project! {
    /// Body type produced by [`RequestDecompression`] and [`Decompression`].
    ///
    /// The variant held in `inner` decides how frames are produced: one of the
    /// decoding variants (gzip, deflate, brotli, zstd) decodes the wrapped body,
    /// while the identity variant passes it through unchanged.
    ///
    /// [`RequestDecompression`]: super::RequestDecompression
    /// [`Decompression`]: super::Decompression
    pub struct DecompressionBody<B>
    where
        B: Body
    {
        // Active decoding strategy. It is structurally pinned so `poll_frame`
        // can project into the selected variant and poll it in place.
        #[pin]
        pub(crate) inner: BodyInner<B>,
    }
}
34
35impl<B> Default for DecompressionBody<B>
36where
37    B: Body + Default,
38{
39    fn default() -> Self {
40        Self {
41            inner: BodyInner::Identity {
42                inner: B::default(),
43            },
44        }
45    }
46}
47
48impl<B> DecompressionBody<B>
49where
50    B: Body,
51{
52    pub(crate) fn new(inner: BodyInner<B>) -> Self {
53        Self { inner }
54    }
55}
56
57type GzipBody<B> = WrapBody<GzipDecoder<B>>;
58type DeflateBody<B> = WrapBody<ZlibDecoder<B>>;
59type BrotliBody<B> = WrapBody<BrotliDecoder<B>>;
60type ZstdBody<B> = WrapBody<ZstdDecoder<B>>;
61
pin_project! {
    // Decoding strategy behind a `DecompressionBody`: one variant per supported
    // content coding, plus `Identity` for bodies that are passed through as-is.
    // `BodyInnerProj` is the pinned projection used by `poll_frame`.
    #[project = BodyInnerProj]
    pub(crate) enum BodyInner<B>
    where
        B: Body,
    {
        // Body decoded with `GzipDecoder`.
        Gzip {
            #[pin]
            inner: GzipBody<B>,
        },
        // Body decoded with `ZlibDecoder` (see the `DeflateBody` alias).
        Deflate {
            #[pin]
            inner: DeflateBody<B>,
        },
        // Body decoded with `BrotliDecoder`.
        Brotli {
            #[pin]
            inner: BrotliBody<B>,
        },
        // Body decoded with `ZstdDecoder`.
        Zstd {
            #[pin]
            inner: ZstdBody<B>,
        },
        // No decoding: frames of `B` are forwarded with only their data
        // converted to `Bytes`.
        Identity {
            #[pin]
            inner: B,
        },
    }
}
90
91impl<B: Body> BodyInner<B> {
92    pub(crate) fn gzip(inner: WrapBody<GzipDecoder<B>>) -> Self {
93        Self::Gzip { inner }
94    }
95
96    pub(crate) fn deflate(inner: WrapBody<ZlibDecoder<B>>) -> Self {
97        Self::Deflate { inner }
98    }
99
100    pub(crate) fn brotli(inner: WrapBody<BrotliDecoder<B>>) -> Self {
101        Self::Brotli { inner }
102    }
103
104    pub(crate) fn zstd(inner: WrapBody<ZstdDecoder<B>>) -> Self {
105        Self::Zstd { inner }
106    }
107
108    pub(crate) fn identity(inner: B) -> Self {
109        Self::Identity { inner }
110    }
111}
112
113impl<B> Body for DecompressionBody<B>
114where
115    B: Body<Error: Into<BoxError>>,
116{
117    type Data = Bytes;
118    type Error = BoxError;
119
120    fn poll_frame(
121        self: Pin<&mut Self>,
122        cx: &mut Context<'_>,
123    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
124        match self.project().inner.project() {
125            BodyInnerProj::Gzip { inner } => inner.poll_frame(cx),
126            BodyInnerProj::Deflate { inner } => inner.poll_frame(cx),
127            BodyInnerProj::Brotli { inner } => inner.poll_frame(cx),
128            BodyInnerProj::Zstd { inner } => inner.poll_frame(cx),
129            BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) {
130                Some(Ok(frame)) => {
131                    let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()));
132                    Poll::Ready(Some(Ok(frame)))
133                }
134                Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
135                None => Poll::Ready(None),
136            },
137        }
138    }
139
140    fn size_hint(&self) -> SizeHint {
141        match self.inner {
142            BodyInner::Identity { ref inner } => inner.size_hint(),
143            _ => SizeHint::default(),
144        }
145    }
146}
147
148impl<B> DecorateAsyncRead for GzipDecoder<B>
149where
150    B: Body,
151{
152    type Input = AsyncReadBody<B>;
153    type Output = GzipDecoder<Self::Input>;
154
155    fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output {
156        let mut decoder = GzipDecoder::new(input);
157        decoder.multiple_members(true);
158        decoder
159    }
160
161    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
162        pinned.get_pin_mut()
163    }
164}
165
166impl<B> DecorateAsyncRead for ZlibDecoder<B>
167where
168    B: Body,
169{
170    type Input = AsyncReadBody<B>;
171    type Output = ZlibDecoder<Self::Input>;
172
173    fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output {
174        ZlibDecoder::new(input)
175    }
176
177    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
178        pinned.get_pin_mut()
179    }
180}
181
182impl<B> DecorateAsyncRead for BrotliDecoder<B>
183where
184    B: Body,
185{
186    type Input = AsyncReadBody<B>;
187    type Output = BrotliDecoder<Self::Input>;
188
189    fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output {
190        BrotliDecoder::new(input)
191    }
192
193    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
194        pinned.get_pin_mut()
195    }
196}
197
198impl<B> DecorateAsyncRead for ZstdDecoder<B>
199where
200    B: Body,
201{
202    type Input = AsyncReadBody<B>;
203    type Output = ZstdDecoder<Self::Input>;
204
205    fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output {
206        let mut decoder = ZstdDecoder::new(input);
207        decoder.multiple_members(true);
208        decoder
209    }
210
211    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
212        pinned.get_pin_mut()
213    }
214}