http_cache_stream/
body.rs

1//! Implementation of an HTTP body.
2
3use std::io;
4use std::path::Path;
5use std::pin::Pin;
6use std::task::Context;
7use std::task::Poll;
8use std::task::ready;
9
10use anyhow::Context as _;
11use anyhow::Result;
12use blake3::Hasher;
13use bytes::Bytes;
14use bytes::BytesMut;
15use futures::Stream;
16use futures::future::BoxFuture;
17use http_body::Frame;
18use pin_project_lite::pin_project;
19use runtime::AsyncWrite;
20use tempfile::NamedTempFile;
21use tempfile::TempPath;
22
23use crate::HttpBody;
24use crate::runtime;
25
26/// The default capacity, in bytes, of the buffer used for reading from files.
///
/// `FileSource` reserves this much space in its read buffer whenever the
/// buffer has no capacity left.
27const DEFAULT_CAPACITY: usize = 4096;
28
29pin_project! {
30    /// Represents the state machine of a caching upstream source.
31    #[project = ProjectedCachingUpstreamSourceState]
32    enum CachingUpstreamSourceState<B> {
33        /// The upstream body is being read.
34        ReadingUpstream {
35            // The upstream response body.
36            #[pin]
37            upstream: B,
38            // The writer for the cache file.
39            #[pin]
40            writer: Option<runtime::BufWriter<runtime::File>>,
41            // The temporary path of the cache file.
42            path: Option<TempPath>,
43            // The current bytes read from the upstream body.
44            current: Bytes,
45            // The hasher used to hash the body.
46            hasher: Hasher,
47            // The callback to invoke once the cache file is completed.
48            callback: Option<Box<dyn FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send>>,
49        },
50        /// The cache file is being flushed.
51        FlushingFile {
52            // The writer for the cache file.
53            #[pin]
54            writer: Option<runtime::BufWriter<runtime::File>>,
55            // The temporary path of the cache file.
56            path: Option<TempPath>,
57            // The digest of the response body.
58            digest: String,
59            // The callback to invoke once the cache file is completed.
60            callback: Option<Box<dyn FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send>>,
61        },
62        /// The callback is being invoked.
63        InvokingCallback {
64            #[pin]
65            future: BoxFuture<'static, Result<()>>,
66        },
67        /// The stream has completed.
68        Completed
69    }
70}
71
72pin_project! {
73    /// Represents a body source from an upstream body that is being cached.
    ///
    /// The source only holds its state machine; polling the stream advances
    /// it by replacing `state` with the next state.
74    struct CachingUpstreamSource<B> {
75        // The current state of the stream's state machine.
76        #[pin]
77        state: CachingUpstreamSourceState<B>,
78    }
79}
80
81impl<B> CachingUpstreamSource<B> {
82    /// Creates a new body source for caching an upstream response.
83    ///
84    /// The callback is invoked after the body has been written to the cache.
85    async fn new<F>(upstream: B, temp_dir: &Path, callback: F) -> Result<Self>
86    where
87        F: FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send + 'static,
88    {
89        let path = NamedTempFile::new_in(temp_dir)
90            .context("failed to create temporary body file for cache storage")?
91            .into_temp_path();
92
93        let file = runtime::File::create(&*path).await.with_context(|| {
94            format!(
95                "failed to create temporary body file `{path}`",
96                path = path.display()
97            )
98        })?;
99
100        Ok(Self {
101            state: CachingUpstreamSourceState::ReadingUpstream {
102                upstream,
103                writer: Some(runtime::BufWriter::new(file)),
104                path: Some(path),
105                callback: Some(Box::new(callback)),
106                current: Bytes::new(),
107                hasher: Hasher::new(),
108            },
109        })
110    }
111}
112
113impl<B> Stream for CachingUpstreamSource<B>
114where
115    B: HttpBody,
116{
117    type Item = io::Result<Bytes>;
118
119    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
120        loop {
121            let this = self.as_mut().project();
122            match this.state.project() {
123                ProjectedCachingUpstreamSourceState::ReadingUpstream {
124                    upstream,
125                    mut writer,
126                    path,
127                    current,
128                    hasher,
129                    callback,
130                } => {
131                    // Check to see if a read is needed
132                    if current.is_empty() {
133                        match ready!(upstream.poll_next_data(cx)) {
134                            Some(Ok(data)) if data.is_empty() => continue,
135                            Some(Ok(data)) => {
136                                // Update the hasher with the data that was read
137                                hasher.update(&data);
138                                *current = data;
139                            }
140                            Some(Err(e)) => {
141                                // Set state to finished and return
142                                self.set(Self {
143                                    state: CachingUpstreamSourceState::Completed,
144                                });
145                                return Poll::Ready(Some(Err(e)));
146                            }
147                            None => {
148                                let writer = writer.take();
149                                let path = path.take();
150                                let digest = hex::encode(hasher.finalize().as_bytes());
151                                let callback = callback.take();
152
153                                // We're done reading from upstream, transition to the flushing
154                                // state
155                                self.set(Self {
156                                    state: CachingUpstreamSourceState::FlushingFile {
157                                        writer,
158                                        path,
159                                        digest,
160                                        callback,
161                                    },
162                                });
163                                continue;
164                            }
165                        }
166                    }
167
168                    // Write the data to the cache and return it to the caller
169                    let mut data = current.clone();
170                    return match ready!(writer.as_pin_mut().unwrap().poll_write(cx, &data)) {
171                        Ok(n) => {
172                            *current = data.split_off(n);
173                            Poll::Ready(Some(Ok(data)))
174                        }
175                        Err(e) => {
176                            self.set(Self {
177                                state: CachingUpstreamSourceState::Completed,
178                            });
179                            Poll::Ready(Some(Err(e)))
180                        }
181                    };
182                }
183                ProjectedCachingUpstreamSourceState::FlushingFile {
184                    mut writer,
185                    path,
186                    digest,
187                    callback,
188                } => {
189                    // Attempt to poll the writer for flush
190                    match ready!(writer.as_mut().as_pin_mut().unwrap().poll_flush(cx)) {
191                        Ok(_) => {
192                            drop(writer.take());
193                            let path = path.take().unwrap();
194                            let digest = std::mem::take(digest);
195                            let callback = callback.take().unwrap();
196
197                            // Invoke the callback and transition to the invoking callback state
198                            let future = callback(digest, path);
199                            self.set(Self {
200                                state: CachingUpstreamSourceState::InvokingCallback { future },
201                            });
202                            continue;
203                        }
204                        Err(e) => {
205                            self.set(Self {
206                                state: CachingUpstreamSourceState::Completed,
207                            });
208                            return Poll::Ready(Some(Err(e)));
209                        }
210                    }
211                }
212                ProjectedCachingUpstreamSourceState::InvokingCallback { future } => {
213                    return match ready!(future.poll(cx)) {
214                        Ok(_) => {
215                            self.set(Self {
216                                state: CachingUpstreamSourceState::Completed,
217                            });
218                            Poll::Ready(None)
219                        }
220                        Err(e) => {
221                            self.set(Self {
222                                state: CachingUpstreamSourceState::Completed,
223                            });
224                            Poll::Ready(Some(Err(io::Error::other(e))))
225                        }
226                    };
227                }
228                ProjectedCachingUpstreamSourceState::Completed => return Poll::Ready(None),
229            }
230        }
231    }
232}
233
234pin_project! {
235    /// Represents a body source from a previously cached response body file.
236    struct FileSource {
237        // The cache file being read.
238        #[pin]
239        reader: runtime::BufReader<runtime::File>,
240        // The length reported as the exact size hint of the body; it is taken
        // from the file's metadata when the body is constructed.
241        len: u64,
242        // The read buffer; the bytes of each successful read are split off and
        // yielded as a chunk.
243        buf: BytesMut,
244        // Whether or not we've finished the stream (end of file or a read error).
245        finished: bool,
246    }
247}
248
249impl Stream for FileSource {
250    type Item = io::Result<Bytes>;
251
252    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
253        let this = self.project();
254
255        if *this.finished {
256            return Poll::Ready(None);
257        }
258
259        if this.buf.capacity() == 0 {
260            this.buf.reserve(DEFAULT_CAPACITY);
261        }
262
263        cfg_if::cfg_if! {
264            if #[cfg(feature = "tokio")] {
265                match ready!(tokio_util::io::poll_read_buf(this.reader, cx, this.buf)) {
266                    Ok(0) => {
267                        *this.finished = true;
268                        Poll::Ready(None)
269                    }
270                    Ok(_) => {
271                        let chunk = this.buf.split();
272                        Poll::Ready(Some(Ok(chunk.freeze())))
273                    }
274                    Err(err) => {
275                        *this.finished = true;
276                        Poll::Ready(Some(Err(err)))
277                    }
278                }
279            } else if #[cfg(feature = "smol")] {
280                use futures::AsyncRead;
281                use bytes::BufMut;
282
283                if !this.buf.has_remaining_mut() {
284                    *this.finished = true;
285                    return Poll::Ready(None);
286                }
287
288                let chunk = this.buf.chunk_mut();
289                // SAFETY: `from_raw_parts_mut` will return a mutable slice treating the memory
290                //         as initialized despite itself being uninitialized.
291                //
292                //         However, we are only using the slice as a read buffer, so the
293                //         uninitialized content of the slice is not actually read.
294                //
295                //         Finally, upon a successful read, we advance the mutable buffer by the
296                //         number of bytes read so that the remaining uninitialized content of
297                //         the buffer will remain for the next poll.
298                let slice =
299                    unsafe { std::slice::from_raw_parts_mut(chunk.as_mut_ptr(), chunk.len()) };
300                match ready!(this.reader.poll_read(cx, slice)) {
301                    Ok(0) => {
302                        *this.finished = true;
303                        Poll::Ready(None)
304                    }
305                    Ok(n) => {
306                        unsafe {
307                            this.buf.advance_mut(n);
308                        }
309                        Poll::Ready(Some(Ok(this.buf.split().freeze())))
310                    }
311                    Err(e) => {
312                        *this.finished = true;
313                        Poll::Ready(Some(Err(e)))
314                    }
315                }
316            } else {
317                unimplemented!()
318            }
319        }
320    }
321}
322
323pin_project! {
324    /// Represents a response body source.
325    ///
326    /// The body may come from one of the following sources:
327    ///
328    /// * Upstream, without caching the response body.
329    /// * Upstream, while caching the response body.
330    /// * A previously cached response body file.
331    #[project = ProjectedBodySource]
332    enum BodySource<B> {
333        /// The body is coming from upstream without being cached.
334        Upstream {
335            // The upstream body itself.
336            #[pin]
337            source: B
338        },
339        /// The body is coming from upstream while being cached.
340        CachingUpstream {
341            // The source that streams the upstream body while caching it.
342            #[pin]
343            source: CachingUpstreamSource<B>,
344        },
345        /// The body is coming from a previously cached response body.
346        File {
347            // The source that streams the cached body file.
348            #[pin]
349            source: FileSource
350        },
351    }
352}
353
354pin_project! {
355    /// Represents a response body.
    ///
    /// The body wraps one of the sources described by `BodySource` and is
    /// generic over the upstream body type `B`.
356    pub struct Body<B> {
357        // The source the body's data is read from.
358        #[pin]
359        source: BodySource<B>
360    }
361}
362
363impl<B> Body<B>
364where
365    B: HttpBody,
366{
367    /// Constructs a new body from an upstream response body that is not being
368    /// cached.
369    pub(crate) fn from_upstream(upstream: B) -> Self {
370        Self {
371            source: BodySource::Upstream { source: upstream },
372        }
373    }
374
375    /// Constructs a new body from an upstream response body that is being
376    /// cached.
377    pub(crate) async fn from_caching_upstream<F>(
378        upstream: B,
379        temp_dir: &Path,
380        callback: F,
381    ) -> Result<Self>
382    where
383        F: FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send + 'static,
384    {
385        Ok(Self {
386            source: BodySource::CachingUpstream {
387                source: CachingUpstreamSource::new(upstream, temp_dir, callback).await?,
388            },
389        })
390    }
391
392    /// Constructs a new body from a local file.
393    pub(crate) async fn from_file(file: runtime::File) -> Result<Self> {
394        let metadata = file.metadata().await?;
395
396        Ok(Self {
397            source: BodySource::File {
398                source: FileSource {
399                    reader: runtime::BufReader::new(file),
400                    len: metadata.len(),
401                    buf: BytesMut::new(),
402                    finished: false,
403                },
404            },
405        })
406    }
407}
408
409impl<B> http_body::Body for Body<B>
410where
411    B: HttpBody,
412{
413    type Data = Bytes;
414    type Error = io::Error;
415
416    fn poll_frame(
417        self: Pin<&mut Self>,
418        cx: &mut Context<'_>,
419    ) -> Poll<Option<Result<Frame<Self::Data>, io::Error>>> {
420        match self.project().source.project() {
421            ProjectedBodySource::Upstream { source } => source.poll_frame(cx),
422            ProjectedBodySource::CachingUpstream { source } => {
423                source.poll_next(cx).map_ok(Frame::data)
424            }
425            ProjectedBodySource::File { source } => source.poll_next(cx).map_ok(Frame::data),
426        }
427    }
428
429    fn is_end_stream(&self) -> bool {
430        match &self.source {
431            BodySource::Upstream { source } => source.is_end_stream(),
432            BodySource::CachingUpstream { source } => {
433                matches!(&source.state, CachingUpstreamSourceState::Completed)
434            }
435            BodySource::File { source } => source.finished,
436        }
437    }
438
439    fn size_hint(&self) -> http_body::SizeHint {
440        match &self.source {
441            BodySource::Upstream { source } => source.size_hint(),
442            BodySource::CachingUpstream { source } => match &source.state {
443                CachingUpstreamSourceState::ReadingUpstream { upstream, .. } => {
444                    upstream.size_hint()
445                }
446                _ => http_body::SizeHint::default(),
447            },
448            BodySource::File { source } => http_body::SizeHint::with_exact(source.len),
449        }
450    }
451}
452
// Marks `Body<B>` as an `HttpBody` whenever the upstream body `B` is an
// `HttpBody` that is also `Send`. The impl is empty, so it relies entirely on
// what the trait (declared elsewhere in the crate) already provides.
453impl<B> HttpBody for Body<B> where B: HttpBody + Send {}
454
455/// An implementation of `Stream` for body.
456///
457/// This implementation only retrieves the data frames of the body.
458///
459/// Trailer frames are not read.
460impl<B> Stream for Body<B>
461where
462    B: HttpBody,
463{
464    type Item = io::Result<Bytes>;
465
466    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
467        match self.project().source.project() {
468            ProjectedBodySource::Upstream { source } => source.poll_next_data(cx),
469            ProjectedBodySource::CachingUpstream { source } => source.poll_next(cx),
470            ProjectedBodySource::File { source } => source.poll_next(cx),
471        }
472    }
473}