// rama_http/layer/util/compression.rs

//! Types used by compression and decompression middleware.

use crate::dep::http_body::{Body, Frame};
use bytes::{Buf, Bytes, BytesMut};
use futures_lite::Stream;
use futures_lite::ready;
use pin_project_lite::pin_project;
use rama_core::error::BoxError;
use std::{
    io,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::io::AsyncRead;
use tokio_util::io::StreamReader;

17/// A `Body` that has been converted into an `AsyncRead`.
18pub(crate) type AsyncReadBody<B> =
19    StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>;
20
21/// Trait for applying some decorator to an `AsyncRead`
22pub(crate) trait DecorateAsyncRead {
23    type Input: AsyncRead;
24    type Output: AsyncRead;
25
26    /// Apply the decorator
27    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;
28
29    /// Get a pinned mutable reference to the original input.
30    ///
31    /// This is necessary to implement `Body::poll_trailers`.
32    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
33}
34
35pin_project! {
36    /// `Body` that has been decorated by an `AsyncRead`
37    pub(crate) struct WrapBody<M: DecorateAsyncRead> {
38        #[pin]
39        // rust-analyser thinks this field is private if its `pub(crate)` but works fine when its
40        // `pub`
41        pub read: M::Output,
42        // A buffer to temporarily store the data read from the underlying body.
43        // Reused as much as possible to optimize allocations.
44        buf: BytesMut,
45        read_all_data: bool,
46    }
47}
48
49impl<M: DecorateAsyncRead> WrapBody<M> {
50    const INTERNAL_BUF_CAPACITY: usize = 4096;
51}
52
53impl<M: DecorateAsyncRead> WrapBody<M> {
54    #[allow(dead_code)]
55    pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
56    where
57        B: Body,
58        M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
59    {
60        // convert `Body` into a `Stream`
61        let stream = BodyIntoStream::new(body);
62
63        // an adapter that converts the error type into `io::Error` while storing the actual error
64        // `StreamReader` requires the error type is `io::Error`
65        let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
66
67        // convert `Stream` into an `AsyncRead`
68        let read = StreamReader::new(stream);
69
70        // apply decorator to `AsyncRead` yielding another `AsyncRead`
71        let read = M::apply(read, quality);
72
73        Self {
74            read,
75            buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY),
76            read_all_data: false,
77        }
78    }
79}
80
81impl<B, M> Body for WrapBody<M>
82where
83    B: Body<Error: Into<BoxError>>,
84    M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
85{
86    type Data = Bytes;
87    type Error = BoxError;
88
89    fn poll_frame(
90        self: Pin<&mut Self>,
91        cx: &mut Context<'_>,
92    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
93        let mut this = self.project();
94
95        if !*this.read_all_data {
96            if this.buf.capacity() == 0 {
97                this.buf.reserve(Self::INTERNAL_BUF_CAPACITY);
98            }
99
100            let result = tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf);
101
102            match ready!(result) {
103                Ok(0) => {
104                    *this.read_all_data = true;
105                }
106                Ok(_) => {
107                    let chunk = this.buf.split().freeze();
108                    return Poll::Ready(Some(Ok(Frame::data(chunk))));
109                }
110                Err(err) => {
111                    let body_error: Option<B::Error> = M::get_pin_mut(this.read)
112                        .get_pin_mut()
113                        .project()
114                        .error
115                        .take();
116
117                    if let Some(body_error) = body_error {
118                        return Poll::Ready(Some(Err(body_error.into())));
119                    } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
120                        // SENTINEL_ERROR_CODE only gets used when storing
121                        // an underlying body error
122                        unreachable!()
123                    } else {
124                        return Poll::Ready(Some(Err(err.into())));
125                    }
126                }
127            }
128        }
129        // poll any remaining frames, such as trailers
130        let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
131        body.poll_frame(cx).map(|option| {
132            option.map(|result| {
133                result
134                    .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining())))
135                    .map_err(|err| err.into())
136            })
137        })
138    }
139}
140
141pin_project! {
142    pub(crate) struct BodyIntoStream<B>
143    where
144        B: Body,
145    {
146        #[pin]
147        body: B,
148        yielded_all_data: bool,
149        non_data_frame: Option<Frame<B::Data>>,
150    }
151}
152
153#[allow(dead_code)]
154impl<B> BodyIntoStream<B>
155where
156    B: Body,
157{
158    pub(crate) fn new(body: B) -> Self {
159        Self {
160            body,
161            yielded_all_data: false,
162            non_data_frame: None,
163        }
164    }
165
166    /// Get a reference to the inner body
167    pub(crate) fn get_ref(&self) -> &B {
168        &self.body
169    }
170
171    /// Get a mutable reference to the inner body
172    pub(crate) fn get_mut(&mut self) -> &mut B {
173        &mut self.body
174    }
175
176    /// Get a pinned mutable reference to the inner body
177    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
178        self.project().body
179    }
180
181    /// Consume `self`, returning the inner body
182    pub(crate) fn into_inner(self) -> B {
183        self.body
184    }
185}
186
187impl<B> Stream for BodyIntoStream<B>
188where
189    B: Body,
190{
191    type Item = Result<B::Data, B::Error>;
192
193    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
194        loop {
195            let this = self.as_mut().project();
196
197            if *this.yielded_all_data {
198                return Poll::Ready(None);
199            }
200
201            match std::task::ready!(this.body.poll_frame(cx)) {
202                Some(Ok(frame)) => match frame.into_data() {
203                    Ok(data) => return Poll::Ready(Some(Ok(data))),
204                    Err(frame) => {
205                        *this.yielded_all_data = true;
206                        *this.non_data_frame = Some(frame);
207                    }
208                },
209                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
210                None => {
211                    *this.yielded_all_data = true;
212                }
213            }
214        }
215    }
216}
217
218impl<B> Body for BodyIntoStream<B>
219where
220    B: Body,
221{
222    type Data = B::Data;
223    type Error = B::Error;
224
225    fn poll_frame(
226        mut self: Pin<&mut Self>,
227        cx: &mut Context<'_>,
228    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
229        // First drive the stream impl. This consumes all data frames and buffer at most one
230        // trailers frame.
231        if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) {
232            return Poll::Ready(Some(frame.map(Frame::data)));
233        }
234
235        let this = self.project();
236
237        // Yield the trailers frame `poll_next` hit.
238        if let Some(frame) = this.non_data_frame.take() {
239            return Poll::Ready(Some(Ok(frame)));
240        }
241
242        // Yield any remaining frames in the body. There shouldn't be any after the trailers but
243        // you never know.
244        this.body.poll_frame(cx)
245    }
246
247    #[inline]
248    fn size_hint(&self) -> rama_http_types::dep::http_body::SizeHint {
249        self.body.size_hint()
250    }
251}
252
253pin_project! {
254    pub(crate) struct StreamErrorIntoIoError<S, E> {
255        #[pin]
256        inner: S,
257        error: Option<E>,
258    }
259}
260
261impl<S, E> StreamErrorIntoIoError<S, E> {
262    pub(crate) fn new(inner: S) -> Self {
263        Self { inner, error: None }
264    }
265
266    /// Get a pinned mutable reference to the inner inner
267    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
268        self.project().inner
269    }
270}
271
272impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
273where
274    S: Stream<Item = Result<T, E>>,
275{
276    type Item = Result<T, io::Error>;
277
278    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
279        let this = self.project();
280        match ready!(this.inner.poll_next(cx)) {
281            None => Poll::Ready(None),
282            Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
283            Some(Err(err)) => {
284                *this.error = Some(err);
285                Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
286            }
287        }
288    }
289}
290
/// Raw OS error code of the placeholder `io::Error` that `StreamErrorIntoIoError`
/// yields after storing the actual error of the stream it wraps.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837459418;

/// Level of compression data should be compressed with.
#[non_exhaustive]
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CompressionLevel {
    /// Fastest quality of compression, usually produces a bigger size.
    Fastest,
    /// Best quality of compression, usually produces the smallest size.
    Best,
    /// Default quality of compression defined by the selected compression
    /// algorithm.
    #[default]
    Default,
    /// Precise quality based on the underlying compression algorithms'
    /// qualities.
    ///
    /// The interpretation of this depends on the algorithm chosen and the
    /// specific implementation backing it.
    ///
    /// Qualities are implicitly clamped to the algorithm's maximum.
    Precise(u32),
}

315use async_compression::Level as AsyncCompressionLevel;
316
317impl CompressionLevel {
318    #[allow(dead_code)]
319    pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
320        match self {
321            CompressionLevel::Fastest => AsyncCompressionLevel::Fastest,
322            CompressionLevel::Best => AsyncCompressionLevel::Best,
323            CompressionLevel::Default => AsyncCompressionLevel::Default,
324            CompressionLevel::Precise(quality) => {
325                AsyncCompressionLevel::Precise(quality.try_into().unwrap_or(i32::MAX))
326            }
327        }
328    }
329}