http_cache_stream/
body.rs

//! Implementation of an HTTP response body served either from an upstream response or from a local file.
2
3use std::io;
4use std::pin::Pin;
5use std::task::Context;
6use std::task::Poll;
7
8use anyhow::Result;
9use blake3::Hash;
10use blake3::Hasher;
11use bytes::Bytes;
12use bytes::BytesMut;
13use futures::Stream;
14use http_body::Frame;
15use pin_project_lite::pin_project;
16
17use crate::HttpBody;
18use crate::runtime;
19
/// Number of bytes requested from a local file per read (4 KiB).
const DEFAULT_CAPACITY: usize = 4 * 1024;
22
pin_project! {
    /// A wrapper around a byte stream that performs a Blake3 hash on the stream data.
    ///
    /// Every successfully yielded chunk is fed to the hasher and then passed through unchanged;
    /// the digest of everything yielded so far is obtained with `HashStream::hash`.
    struct HashStream<S> {
        // The wrapped stream; structurally pinned so it can be polled through `Pin`.
        #[pin]
        stream: S,
        // Running Blake3 state over every chunk yielded without error.
        hasher: Hasher,
        // Set once the inner stream ends or yields an error; it is not polled after that.
        finished: bool,
    }
}
32
33impl<S> HashStream<S> {
34    /// Constructs a new hash stream.
35    fn new(stream: S) -> Self
36    where
37        S: Stream<Item = io::Result<Bytes>>,
38    {
39        Self {
40            stream,
41            hasher: Hasher::new(),
42            finished: false,
43        }
44    }
45
46    /// Computes the hash of the byte stream.
47    fn hash(self) -> Hash {
48        self.hasher.finalize()
49    }
50}
51
52impl<S> Stream for HashStream<S>
53where
54    S: Stream<Item = io::Result<Bytes>>,
55{
56    type Item = io::Result<Bytes>;
57
58    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
59        if self.finished {
60            return Poll::Ready(None);
61        }
62
63        let this = self.project();
64        match this.stream.poll_next(cx) {
65            Poll::Ready(Some(Ok(bytes))) => {
66                this.hasher.update(&bytes);
67                Poll::Ready(Some(Ok(bytes)))
68            }
69            Poll::Ready(Some(Err(e))) => {
70                *this.finished = true;
71                Poll::Ready(Some(Err(e)))
72            }
73            Poll::Ready(None) => {
74                *this.finished = true;
75                Poll::Ready(None)
76            }
77            Poll::Pending => Poll::Pending,
78        }
79    }
80}
81
pin_project! {
    /// Represents the source of a response body.
    // `project` names the generated projection enum (`ProjectedSource`) that is matched on
    // when polling a pinned source.
    #[project = ProjectedSource]
    enum Source<B> {
        /// The body is coming from an upstream response.
        Upstream {
            // The upstream body; pinned because its frames are polled through `Pin`.
            #[pin]
            body: B,
        },
        /// The body is coming from a local file.
        File {
            // The open file the body's bytes are read from.
            #[pin]
            file: runtime::File,
            // Byte length reported as the body's size hint (initialized from the file metadata).
            len: u64,
            // Scratch buffer that file reads land in before being split off as data frames.
            buf: BytesMut,
            // Set once reading the file has hit end-of-file or an error.
            finished: bool,
        },
    }
}
101
pin_project! {
    /// Represents a response body from the HTTP cache.
    ///
    /// The response body may be from upstream or from a local file.
    ///
    /// It can be consumed frame-by-frame through `http_body::Body`, or as a `Stream` of data
    /// chunks (the stream never yields trailers).
    pub struct Body<B> {
        // Where the body's bytes come from; pinned so the active variant can be polled.
        #[pin]
        source: Source<B>,
    }
}
111
112impl<B: HttpBody> Body<B> {
113    /// Constructs a new body from a local file.
114    pub(crate) async fn from_file(file: runtime::File) -> Result<Self> {
115        let metadata = file.metadata().await?;
116
117        Ok(Self {
118            source: Source::File {
119                file,
120                len: metadata.len(),
121                buf: BytesMut::new(),
122                finished: false,
123            },
124        })
125    }
126
127    /// Constructs a new body from an upstream response.
128    pub(crate) fn from_upstream(body: B) -> Self {
129        Self {
130            source: Source::Upstream { body },
131        }
132    }
133
134    /// Writes the body to the given file.
135    ///
136    /// Trailers in the body are not stored.
137    ///
138    /// Returns the calculated Blake3 hash of the body.
139    pub(crate) async fn write_to(self, file: &mut runtime::File) -> io::Result<String> {
140        cfg_if::cfg_if! {
141            if #[cfg(feature = "tokio")] {
142                let this = self;
143                tokio::pin!(this);
144
145                let mut stream = HashStream::new(this);
146                let mut reader = tokio_util::io::StreamReader::new(&mut stream);
147                tokio::io::copy(&mut reader, file).await?;
148                Ok(hex::encode(stream.hash().as_bytes()))
149            } else if #[cfg(feature = "smol")] {
150                use futures::stream::TryStreamExt;
151                let this = self;
152                futures::pin_mut!(this);
153
154                let mut stream = HashStream::new(this);
155                let mut reader = (&mut stream).into_async_read();
156                smol::io::copy(&mut reader, file).await?;
157                Ok(hex::encode(stream.hash().as_bytes()))
158            } else {
159                unimplemented!()
160            }
161        }
162    }
163}
164
165impl<B: HttpBody> http_body::Body for Body<B> {
166    type Data = Bytes;
167    type Error = io::Error;
168
169    fn poll_frame(
170        mut self: Pin<&mut Self>,
171        cx: &mut Context<'_>,
172    ) -> Poll<Option<Result<Frame<Self::Data>, io::Error>>> {
173        match self.as_mut().project().source.project() {
174            ProjectedSource::Upstream { body } => body.poll_frame(cx),
175            ProjectedSource::File {
176                file,
177                len: _,
178                buf,
179                finished,
180            } => {
181                if *finished {
182                    return Poll::Ready(None);
183                }
184
185                if buf.capacity() == 0 {
186                    buf.reserve(DEFAULT_CAPACITY);
187                }
188
189                cfg_if::cfg_if! {
190                    if #[cfg(feature = "tokio")] {
191                        match tokio_util::io::poll_read_buf(file, cx, buf) {
192                            Poll::Pending => Poll::Pending,
193                            Poll::Ready(Err(err)) => {
194                                *finished = true;
195                                Poll::Ready(Some(Err(err)))
196                            }
197                            Poll::Ready(Ok(0)) => {
198                                *finished = true;
199                                Poll::Ready(None)
200                            }
201                            Poll::Ready(Ok(_)) => {
202                                let chunk = buf.split();
203                                Poll::Ready(Some(Ok(Frame::data(chunk.freeze()))))
204                            }
205                        }
206                    } else if #[cfg(feature = "smol")] {
207                        use futures::AsyncRead;
208                        use bytes::BufMut;
209
210                        if !buf.has_remaining_mut() {
211                            *finished = true;
212                            return Poll::Ready(None);
213                        }
214
215                        let chunk = buf.chunk_mut();
216                        let slice =
217                            unsafe { std::slice::from_raw_parts_mut(chunk.as_mut_ptr(), chunk.len()) };
218                        match file.poll_read(cx, slice) {
219                            Poll::Ready(Ok(0)) => {
220                                *finished = true;
221                                Poll::Ready(None)
222                            }
223                            Poll::Ready(Ok(n)) => {
224                                unsafe {
225                                    buf.advance_mut(n);
226                                }
227                                Poll::Ready(Some(Ok(Frame::data(buf.split().freeze()))))
228                            }
229                            Poll::Ready(Err(e)) => {
230                                *finished = true;
231                                Poll::Ready(Some(Err(e)))
232                            }
233                            Poll::Pending => Poll::Pending,
234                        }
235                    } else {
236                        unimplemented!()
237                    }
238                }
239            }
240        }
241    }
242
243    fn is_end_stream(&self) -> bool {
244        match &self.source {
245            Source::Upstream { body } => body.is_end_stream(),
246            Source::File { finished, .. } => *finished,
247        }
248    }
249
250    fn size_hint(&self) -> http_body::SizeHint {
251        match &self.source {
252            Source::Upstream { body } => body.size_hint(),
253            Source::File { len, .. } => http_body::SizeHint::with_exact(*len),
254        }
255    }
256}
257
258impl<B: HttpBody> HttpBody for Body<B> {}
259
260/// An implementation of `Stream` for body.
261///
262/// This implementation only retrieves the data frames of the body.
263///
264/// Trailer frames are not read.
265impl<B: HttpBody> Stream for Body<B> {
266    type Item = io::Result<Bytes>;
267
268    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
269        use http_body::Body;
270
271        match self.poll_frame(cx) {
272            Poll::Ready(Some(Ok(frame))) => match frame.into_data().ok() {
273                Some(data) => Poll::Ready(Some(Ok(data))),
274                None => Poll::Ready(None),
275            },
276            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
277            Poll::Ready(None) => Poll::Ready(None),
278            Poll::Pending => Poll::Pending,
279        }
280    }
281}