http_cache_stream/
body.rs

//! Implementation of an HTTP body.
2
3use std::io;
4use std::path::Path;
5use std::pin::Pin;
6use std::task::Context;
7use std::task::Poll;
8use std::task::ready;
9
10use anyhow::Context as _;
11use anyhow::Result;
12use blake3::Hasher;
13use bytes::Bytes;
14use bytes::BytesMut;
15use futures::Stream;
16use futures::future::BoxFuture;
17use http_body::Body;
18use http_body::Frame;
19use http_body_util::BodyStream;
20use pin_project_lite::pin_project;
21use runtime::AsyncWrite;
22use tempfile::NamedTempFile;
23use tempfile::TempPath;
24
25use crate::runtime;
26
/// Buffer capacity, in bytes, used when reading cached body files.
const DEFAULT_CAPACITY: usize = 4 * 1024;
29
pin_project! {
    /// The state machine driving a caching upstream body source.
    ///
    /// States advance `ReadingUpstream` -> `FlushingFile` -> `InvokingCallback`
    /// -> `Completed`; any upstream, write, flush, or callback error moves the
    /// state straight to `Completed`.
    ///
    /// Values that must be moved into the next state (`writer`, `path`,
    /// `callback`) are wrapped in `Option` so they can be taken out with
    /// `Option::take` during a transition.
    #[project = ProjectedCachingUpstreamSourceState]
    enum CachingUpstreamSourceState<B> {
        /// The upstream body is being read and teed into the cache file.
        ReadingUpstream {
            // The upstream response body, adapted into a stream of frames.
            #[pin]
            upstream: BodyStream<B>,
            // The buffered writer for the cache file.
            // NOTE(review): declared before `path`, so it is dropped first (fields
            // drop in declaration order); presumably intentional so the file handle
            // is closed before the temporary path is released — confirm.
            #[pin]
            writer: Option<runtime::BufWriter<runtime::File>>,
            // The temporary path of the cache file.
            path: Option<TempPath>,
            // Bytes read (and already hashed) from upstream that have not yet been
            // written to the cache file and returned to the caller.
            current: Bytes,
            // BLAKE3 hasher fed with every upstream data byte as it is read.
            hasher: Hasher,
            // The callback to invoke with the digest and path once the cache file
            // is complete.
            callback: Option<Box<dyn FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send>>,
        },
        /// Upstream has ended and the cache file is being flushed.
        FlushingFile {
            // The buffered writer for the cache file.
            #[pin]
            writer: Option<runtime::BufWriter<runtime::File>>,
            // The temporary path of the cache file.
            path: Option<TempPath>,
            // The hex-encoded BLAKE3 digest of the full response body.
            digest: String,
            // The callback to invoke with the digest and path once the cache file
            // is complete.
            callback: Option<Box<dyn FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send>>,
        },
        /// The completion callback has been called and its future is being driven.
        InvokingCallback {
            // The future returned by the completion callback.
            #[pin]
            future: BoxFuture<'static, Result<()>>,
        },
        /// The stream has completed (successfully or with an error).
        Completed
    }
}
72
pin_project! {
    /// A body source that streams an upstream body to the caller while writing it
    /// to a temporary cache file.
    ///
    /// Once the upstream body has been fully written and flushed, a completion
    /// callback receives the body's digest and the temporary file path.
    struct CachingUpstreamSource<B> {
        // The current state of the caching state machine.
        #[pin]
        state: CachingUpstreamSourceState<B>,
    }
}
81
82impl<B> CachingUpstreamSource<B> {
83    /// Creates a new body source for caching an upstream response.
84    ///
85    /// The callback is invoked after the body has been written to the cache.
86    async fn new<F>(upstream: B, temp_dir: &Path, callback: F) -> Result<Self>
87    where
88        F: FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send + 'static,
89    {
90        let path = NamedTempFile::new_in(temp_dir)
91            .context("failed to create temporary body file for cache storage")?
92            .into_temp_path();
93
94        let file = runtime::File::create(&*path).await.with_context(|| {
95            format!(
96                "failed to create temporary body file `{path}`",
97                path = path.display()
98            )
99        })?;
100
101        Ok(Self {
102            state: CachingUpstreamSourceState::ReadingUpstream {
103                upstream: BodyStream::new(upstream),
104                writer: Some(runtime::BufWriter::new(file)),
105                path: Some(path),
106                callback: Some(Box::new(callback)),
107                current: Bytes::new(),
108                hasher: Hasher::new(),
109            },
110        })
111    }
112}
113
114impl<B> Body for CachingUpstreamSource<B>
115where
116    B: Body,
117    B::Data: Into<Bytes>,
118    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
119{
120    type Data = Bytes;
121    type Error = Box<dyn std::error::Error + Send + Sync>;
122
123    fn poll_frame(
124        mut self: Pin<&mut Self>,
125        cx: &mut Context<'_>,
126    ) -> Poll<Option<std::result::Result<Frame<Self::Data>, Self::Error>>> {
127        loop {
128            let this = self.as_mut().project();
129            match this.state.project() {
130                ProjectedCachingUpstreamSourceState::ReadingUpstream {
131                    upstream,
132                    mut writer,
133                    path,
134                    current,
135                    hasher,
136                    callback,
137                } => {
138                    // Check to see if a read is needed
139                    if current.is_empty() {
140                        match ready!(upstream.poll_next(cx)) {
141                            Some(Ok(frame)) => {
142                                let frame = frame.map_data(Into::into);
143                                match frame.into_data() {
144                                    Ok(data) if !data.is_empty() => {
145                                        // Update the hasher with the data that was read
146                                        hasher.update(&data);
147                                        *current = data;
148                                    }
149                                    Ok(_) => continue,
150                                    Err(frame) => return Poll::Ready(Some(Ok(frame))),
151                                }
152                            }
153                            Some(Err(e)) => {
154                                // Set state to finished and return
155                                self.set(Self {
156                                    state: CachingUpstreamSourceState::Completed,
157                                });
158                                return Poll::Ready(Some(Err(e.into())));
159                            }
160                            None => {
161                                let writer = writer.take();
162                                let path = path.take();
163                                let digest = hex::encode(hasher.finalize().as_bytes());
164                                let callback = callback.take();
165
166                                // We're done reading from upstream, transition to the flushing
167                                // state
168                                self.set(Self {
169                                    state: CachingUpstreamSourceState::FlushingFile {
170                                        writer,
171                                        path,
172                                        digest,
173                                        callback,
174                                    },
175                                });
176                                continue;
177                            }
178                        }
179                    }
180
181                    // Write the data to the cache and return it to the caller
182                    let mut data = current.clone();
183                    return match ready!(writer.as_pin_mut().unwrap().poll_write(cx, &data)) {
184                        Ok(n) => {
185                            *current = data.split_off(n);
186                            Poll::Ready(Some(Ok(Frame::data(data))))
187                        }
188                        Err(e) => {
189                            self.set(Self {
190                                state: CachingUpstreamSourceState::Completed,
191                            });
192                            Poll::Ready(Some(Err(e.into())))
193                        }
194                    };
195                }
196                ProjectedCachingUpstreamSourceState::FlushingFile {
197                    mut writer,
198                    path,
199                    digest,
200                    callback,
201                } => {
202                    // Attempt to poll the writer for flush
203                    match ready!(writer.as_mut().as_pin_mut().unwrap().poll_flush(cx)) {
204                        Ok(_) => {
205                            drop(writer.take());
206                            let path = path.take().unwrap();
207                            let digest = std::mem::take(digest);
208                            let callback = callback.take().unwrap();
209
210                            // Invoke the callback and transition to the invoking callback state
211                            let future = callback(digest, path);
212                            self.set(Self {
213                                state: CachingUpstreamSourceState::InvokingCallback { future },
214                            });
215                            continue;
216                        }
217                        Err(e) => {
218                            self.set(Self {
219                                state: CachingUpstreamSourceState::Completed,
220                            });
221                            return Poll::Ready(Some(Err(e.into())));
222                        }
223                    }
224                }
225                ProjectedCachingUpstreamSourceState::InvokingCallback { future } => {
226                    return match ready!(future.poll(cx)) {
227                        Ok(_) => {
228                            self.set(Self {
229                                state: CachingUpstreamSourceState::Completed,
230                            });
231                            Poll::Ready(None)
232                        }
233                        Err(e) => {
234                            self.set(Self {
235                                state: CachingUpstreamSourceState::Completed,
236                            });
237                            Poll::Ready(Some(Err(e.into_boxed_dyn_error())))
238                        }
239                    };
240                }
241                ProjectedCachingUpstreamSourceState::Completed => return Poll::Ready(None),
242            }
243        }
244    }
245}
246
pin_project! {
    /// A body source that reads a previously cached response body from a file.
    struct FileSource {
        // Buffered reader over the cache file.
        #[pin]
        reader: runtime::BufReader<runtime::File>,
        // The length of the file in bytes, taken from its metadata when the source
        // is constructed; reported as the body's exact size hint.
        len: u64,
        // Read buffer; each chunk that is read is split off and returned as a data
        // frame.
        buf: BytesMut,
        // Set once the file reaches EOF or a read fails; afterwards the body yields
        // no more frames.
        finished: bool,
    }
}
261
262impl Body for FileSource {
263    type Data = Bytes;
264    type Error = io::Error;
265
266    fn poll_frame(
267        self: Pin<&mut Self>,
268        cx: &mut Context<'_>,
269    ) -> Poll<Option<io::Result<Frame<Self::Data>>>> {
270        let this = self.project();
271
272        if *this.finished {
273            return Poll::Ready(None);
274        }
275
276        if this.buf.capacity() == 0 {
277            this.buf.reserve(DEFAULT_CAPACITY);
278        }
279
280        cfg_if::cfg_if! {
281            if #[cfg(feature = "tokio")] {
282                match ready!(tokio_util::io::poll_read_buf(this.reader, cx, this.buf)) {
283                    Ok(0) => {
284                        *this.finished = true;
285                        Poll::Ready(None)
286                    }
287                    Ok(_) => {
288                        let chunk = this.buf.split();
289                        Poll::Ready(Some(Ok(Frame::data(chunk.freeze()))))
290                    }
291                    Err(err) => {
292                        *this.finished = true;
293                        Poll::Ready(Some(Err(err)))
294                    }
295                }
296            } else if #[cfg(feature = "smol")] {
297                use futures::AsyncRead;
298                use bytes::BufMut;
299
300                if !this.buf.has_remaining_mut() {
301                    *this.finished = true;
302                    return Poll::Ready(None);
303                }
304
305                let chunk = this.buf.chunk_mut();
306                // SAFETY: `from_raw_parts_mut` will return a mutable slice treating the memory
307                //         as initialized despite itself being uninitialized.
308                //
309                //         However, we are only using the slice as a read buffer, so the
310                //         uninitialized content of the slice is not actually read.
311                //
312                //         Finally, upon a successful read, we advance the mutable buffer by the
313                //         number of bytes read so that the remaining uninitialized content of
314                //         the buffer will remain for the next poll.
315                let slice =
316                    unsafe { std::slice::from_raw_parts_mut(chunk.as_mut_ptr(), chunk.len()) };
317                match ready!(this.reader.poll_read(cx, slice)) {
318                    Ok(0) => {
319                        *this.finished = true;
320                        Poll::Ready(None)
321                    }
322                    Ok(n) => {
323                        unsafe {
324                            this.buf.advance_mut(n);
325                        }
326                        Poll::Ready(Some(Ok(Frame::data(this.buf.split().freeze()))))
327                    }
328                    Err(e) => {
329                        *this.finished = true;
330                        Poll::Ready(Some(Err(e)))
331                    }
332                }
333            } else {
334                unimplemented!()
335            }
336        }
337    }
338}
339
pin_project! {
    /// Represents a response body source.
    ///
    /// The body may come from the following sources:
    ///
    /// * Upstream without caching the response body.
    /// * Upstream while caching the response body.
    /// * A previously cached response body file.
    #[project = ProjectedBodySource]
    enum BodySource<B> {
        /// The body is streamed from upstream without being cached.
        Upstream {
            // The upstream body, adapted into a stream of frames.
            #[pin]
            source: BodyStream<B>
        },
        /// The body is streamed from upstream while being written to the cache.
        CachingUpstream {
            // The source that tees upstream data into a temporary cache file.
            #[pin]
            source: CachingUpstreamSource<B>,
        },
        /// The body is read from a previously cached response body file.
        File {
            // The source reading the cached body file.
            #[pin]
            source: FileSource
        },
    }
}
370
pin_project! {
    /// The response body produced by the cache.
    ///
    /// The body is sourced from one of the following:
    ///
    /// * An upstream response, passed through without caching.
    /// * An upstream response that is teed into the cache as it streams.
    /// * A previously cached response body file.
    pub struct CacheBody<B> {
        // Where the body's frames come from.
        #[pin]
        source: BodySource<B>
    }
}
381
382impl<B> CacheBody<B>
383where
384    B: Body,
385{
386    /// Constructs a new body from an upstream response body that is not being
387    /// cached.
388    pub(crate) fn from_upstream(upstream: B) -> Self {
389        Self {
390            source: BodySource::Upstream {
391                source: BodyStream::new(upstream),
392            },
393        }
394    }
395
396    /// Constructs a new body from an upstream response body that is being
397    /// cached.
398    pub(crate) async fn from_caching_upstream<F>(
399        upstream: B,
400        temp_dir: &Path,
401        callback: F,
402    ) -> Result<Self>
403    where
404        F: FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send + 'static,
405    {
406        Ok(Self {
407            source: BodySource::CachingUpstream {
408                source: CachingUpstreamSource::new(upstream, temp_dir, callback).await?,
409            },
410        })
411    }
412
413    /// Constructs a new body from a local file.
414    pub(crate) async fn from_file(file: runtime::File) -> Result<Self> {
415        let metadata = file.metadata().await?;
416
417        Ok(Self {
418            source: BodySource::File {
419                source: FileSource {
420                    reader: runtime::BufReader::new(file),
421                    len: metadata.len(),
422                    buf: BytesMut::new(),
423                    finished: false,
424                },
425            },
426        })
427    }
428}
429
430impl<B> Body for CacheBody<B>
431where
432    B: Body,
433    B::Data: Into<Bytes>,
434    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
435{
436    type Data = Bytes;
437    type Error = Box<dyn std::error::Error + Send + Sync>;
438
439    fn poll_frame(
440        self: Pin<&mut Self>,
441        cx: &mut Context<'_>,
442    ) -> Poll<Option<std::result::Result<http_body::Frame<Self::Data>, Self::Error>>> {
443        match self.project().source.project() {
444            ProjectedBodySource::Upstream { source } => source
445                .poll_frame(cx)
446                .map_ok(|f| f.map_data(Into::into))
447                .map_err(Into::into),
448            ProjectedBodySource::CachingUpstream { source } => source.poll_frame(cx),
449            ProjectedBodySource::File { source } => source.poll_frame(cx).map_err(Into::into),
450        }
451    }
452
453    fn is_end_stream(&self) -> bool {
454        match &self.source {
455            BodySource::Upstream { source } => source.is_end_stream(),
456            BodySource::CachingUpstream { source } => {
457                matches!(&source.state, CachingUpstreamSourceState::Completed)
458            }
459            BodySource::File { source } => source.finished,
460        }
461    }
462
463    fn size_hint(&self) -> http_body::SizeHint {
464        match &self.source {
465            BodySource::Upstream { source } => Body::size_hint(source),
466            BodySource::CachingUpstream { source } => match &source.state {
467                CachingUpstreamSourceState::ReadingUpstream { upstream, .. } => {
468                    Body::size_hint(upstream)
469                }
470                _ => http_body::SizeHint::default(),
471            },
472            BodySource::File { source } => http_body::SizeHint::with_exact(source.len),
473        }
474    }
475}