rama_http/layer/util/
compression.rs1use crate::dep::http_body::{Body, Frame};
4use bytes::{Buf, Bytes, BytesMut};
5use futures_lite::Stream;
6use futures_lite::ready;
7use pin_project_lite::pin_project;
8use rama_core::error::BoxError;
9use std::{
10 io,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use tokio::io::AsyncRead;
15use tokio_util::io::StreamReader;
16
17pub(crate) type AsyncReadBody<B> =
19 StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>;
20
21pub(crate) trait DecorateAsyncRead {
23 type Input: AsyncRead;
24 type Output: AsyncRead;
25
26 fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;
28
29 fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
33}
34
35pin_project! {
36 pub(crate) struct WrapBody<M: DecorateAsyncRead> {
38 #[pin]
39 pub read: M::Output,
42 buf: BytesMut,
45 read_all_data: bool,
46 }
47}
48
49impl<M: DecorateAsyncRead> WrapBody<M> {
50 const INTERNAL_BUF_CAPACITY: usize = 4096;
51}
52
53impl<M: DecorateAsyncRead> WrapBody<M> {
54 #[allow(dead_code)]
55 pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
56 where
57 B: Body,
58 M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
59 {
60 let stream = BodyIntoStream::new(body);
62
63 let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
66
67 let read = StreamReader::new(stream);
69
70 let read = M::apply(read, quality);
72
73 Self {
74 read,
75 buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY),
76 read_all_data: false,
77 }
78 }
79}
80
81impl<B, M> Body for WrapBody<M>
82where
83 B: Body<Error: Into<BoxError>>,
84 M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
85{
86 type Data = Bytes;
87 type Error = BoxError;
88
89 fn poll_frame(
90 self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
93 let mut this = self.project();
94
95 if !*this.read_all_data {
96 if this.buf.capacity() == 0 {
97 this.buf.reserve(Self::INTERNAL_BUF_CAPACITY);
98 }
99
100 let result = tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf);
101
102 match ready!(result) {
103 Ok(0) => {
104 *this.read_all_data = true;
105 }
106 Ok(_) => {
107 let chunk = this.buf.split().freeze();
108 return Poll::Ready(Some(Ok(Frame::data(chunk))));
109 }
110 Err(err) => {
111 let body_error: Option<B::Error> = M::get_pin_mut(this.read)
112 .get_pin_mut()
113 .project()
114 .error
115 .take();
116
117 if let Some(body_error) = body_error {
118 return Poll::Ready(Some(Err(body_error.into())));
119 } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
120 unreachable!()
123 } else {
124 return Poll::Ready(Some(Err(err.into())));
125 }
126 }
127 }
128 }
129 let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
131 body.poll_frame(cx).map(|option| {
132 option.map(|result| {
133 result
134 .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining())))
135 .map_err(|err| err.into())
136 })
137 })
138 }
139}
140
141pin_project! {
142 pub(crate) struct BodyIntoStream<B>
143 where
144 B: Body,
145 {
146 #[pin]
147 body: B,
148 yielded_all_data: bool,
149 non_data_frame: Option<Frame<B::Data>>,
150 }
151}
152
153#[allow(dead_code)]
154impl<B> BodyIntoStream<B>
155where
156 B: Body,
157{
158 pub(crate) fn new(body: B) -> Self {
159 Self {
160 body,
161 yielded_all_data: false,
162 non_data_frame: None,
163 }
164 }
165
166 pub(crate) fn get_ref(&self) -> &B {
168 &self.body
169 }
170
171 pub(crate) fn get_mut(&mut self) -> &mut B {
173 &mut self.body
174 }
175
176 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
178 self.project().body
179 }
180
181 pub(crate) fn into_inner(self) -> B {
183 self.body
184 }
185}
186
187impl<B> Stream for BodyIntoStream<B>
188where
189 B: Body,
190{
191 type Item = Result<B::Data, B::Error>;
192
193 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
194 loop {
195 let this = self.as_mut().project();
196
197 if *this.yielded_all_data {
198 return Poll::Ready(None);
199 }
200
201 match std::task::ready!(this.body.poll_frame(cx)) {
202 Some(Ok(frame)) => match frame.into_data() {
203 Ok(data) => return Poll::Ready(Some(Ok(data))),
204 Err(frame) => {
205 *this.yielded_all_data = true;
206 *this.non_data_frame = Some(frame);
207 }
208 },
209 Some(Err(err)) => return Poll::Ready(Some(Err(err))),
210 None => {
211 *this.yielded_all_data = true;
212 }
213 }
214 }
215 }
216}
217
218impl<B> Body for BodyIntoStream<B>
219where
220 B: Body,
221{
222 type Data = B::Data;
223 type Error = B::Error;
224
225 fn poll_frame(
226 mut self: Pin<&mut Self>,
227 cx: &mut Context<'_>,
228 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
229 if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) {
232 return Poll::Ready(Some(frame.map(Frame::data)));
233 }
234
235 let this = self.project();
236
237 if let Some(frame) = this.non_data_frame.take() {
239 return Poll::Ready(Some(Ok(frame)));
240 }
241
242 this.body.poll_frame(cx)
245 }
246
247 #[inline]
248 fn size_hint(&self) -> rama_http_types::dep::http_body::SizeHint {
249 self.body.size_hint()
250 }
251}
252
253pin_project! {
254 pub(crate) struct StreamErrorIntoIoError<S, E> {
255 #[pin]
256 inner: S,
257 error: Option<E>,
258 }
259}
260
261impl<S, E> StreamErrorIntoIoError<S, E> {
262 pub(crate) fn new(inner: S) -> Self {
263 Self { inner, error: None }
264 }
265
266 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
268 self.project().inner
269 }
270}
271
272impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
273where
274 S: Stream<Item = Result<T, E>>,
275{
276 type Item = Result<T, io::Error>;
277
278 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
279 let this = self.project();
280 match ready!(this.inner.poll_next(cx)) {
281 None => Poll::Ready(None),
282 Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
283 Some(Err(err)) => {
284 *this.error = Some(err);
285 Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
286 }
287 }
288 }
289}
290
/// Raw OS error code marking a stored stream error.
///
/// `StreamErrorIntoIoError` puts this code on the `io::Error` it emits in place
/// of a stream error, and keeps the original error itself. An arbitrary
/// negative value is used, presumably to avoid colliding with real OS error
/// codes.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837_459_418;
292
/// Compression level requested for a compression codec.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CompressionLevel {
    /// Converted to `async_compression`'s `Fastest` level.
    Fastest,
    /// Converted to `async_compression`'s `Best` level.
    Best,
    /// Converted to `async_compression`'s `Default` level.
    ///
    /// This is also the value returned by `CompressionLevel::default()`.
    #[default]
    Default,
    /// An explicit numeric level whose meaning depends on the codec.
    ///
    /// Values above `i32::MAX` are saturated when converted.
    Precise(u32),
}
314
315use async_compression::Level as AsyncCompressionLevel;
316
317impl CompressionLevel {
318 #[allow(dead_code)]
319 pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
320 match self {
321 CompressionLevel::Fastest => AsyncCompressionLevel::Fastest,
322 CompressionLevel::Best => AsyncCompressionLevel::Best,
323 CompressionLevel::Default => AsyncCompressionLevel::Default,
324 CompressionLevel::Precise(quality) => {
325 AsyncCompressionLevel::Precise(quality.try_into().unwrap_or(i32::MAX))
326 }
327 }
328 }
329}