1use std::io;
4use std::path::Path;
5use std::pin::Pin;
6use std::task::Context;
7use std::task::Poll;
8use std::task::ready;
9
10use anyhow::Context as _;
11use anyhow::Result;
12use blake3::Hasher;
13use bytes::Bytes;
14use bytes::BytesMut;
15use futures::Stream;
16use futures::future::BoxFuture;
17use http_body::Body;
18use http_body::Frame;
19use http_body_util::BodyStream;
20use pin_project_lite::pin_project;
21use runtime::AsyncWrite;
22use tempfile::NamedTempFile;
23use tempfile::TempPath;
24
25use crate::runtime;
26
/// Number of bytes reserved for the `FileSource` read buffer whenever it has no capacity left.
const DEFAULT_CAPACITY: usize = 4096;
29
pin_project! {
    #[project = ProjectedCachingUpstreamSourceState]
    /// State machine behind `CachingUpstreamSource`.
    ///
    /// A successful run moves through the variants in declaration order:
    /// `ReadingUpstream` -> `FlushingFile` -> `InvokingCallback` -> `Completed`.
    /// An upstream, write or flush error jumps straight to `Completed`.
    enum CachingUpstreamSourceState<B> {
        // Frames are being pulled from upstream; data is written to the temporary
        // file before being yielded to the consumer.
        ReadingUpstream {
            // The upstream body, wrapped so it can be polled as a stream of frames.
            #[pin]
            upstream: BodyStream<B>,
            // Buffered writer over the temporary file. An `Option` so it can be
            // taken out of the pinned state when moving to `FlushingFile`.
            #[pin]
            writer: Option<runtime::BufWriter<runtime::File>>,
            // Path of the temporary file; handed to the callback at the end.
            path: Option<TempPath>,
            // Data read from upstream that has not been fully written (and
            // therefore not yet yielded). Empty when a new read is needed.
            current: Bytes,
            // BLAKE3 hasher updated with every non-empty data chunk read.
            hasher: Hasher,
            // Called with the hex digest and the temporary path once the file
            // has been flushed.
            callback: Option<Box<dyn FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send>>,
        },
        // Upstream is exhausted; the buffered writer is being flushed.
        FlushingFile {
            // The writer carried over from `ReadingUpstream`.
            #[pin]
            writer: Option<runtime::BufWriter<runtime::File>>,
            // The temporary file path carried over from `ReadingUpstream`.
            path: Option<TempPath>,
            // Hex-encoded digest of all data that was read.
            digest: String,
            // The callback carried over from `ReadingUpstream`.
            callback: Option<Box<dyn FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send>>,
        },
        // The callback has been invoked and its future is being driven.
        InvokingCallback {
            // Future returned by the callback.
            #[pin]
            future: BoxFuture<'static, Result<()>>,
        },
        // Terminal state: polling yields `None`.
        Completed
    }
}
72
pin_project! {
    /// Body source that forwards upstream frames while copying their data to a
    /// temporary file and hashing it.
    struct CachingUpstreamSource<B> {
        // Current step of the read / flush / callback sequence.
        #[pin]
        state: CachingUpstreamSourceState<B>,
    }
}
81
82impl<B> CachingUpstreamSource<B> {
83 async fn new<F>(upstream: B, temp_dir: &Path, callback: F) -> Result<Self>
87 where
88 F: FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send + 'static,
89 {
90 let path = NamedTempFile::new_in(temp_dir)
91 .context("failed to create temporary body file for cache storage")?
92 .into_temp_path();
93
94 let file = runtime::File::create(&*path).await.with_context(|| {
95 format!(
96 "failed to create temporary body file `{path}`",
97 path = path.display()
98 )
99 })?;
100
101 Ok(Self {
102 state: CachingUpstreamSourceState::ReadingUpstream {
103 upstream: BodyStream::new(upstream),
104 writer: Some(runtime::BufWriter::new(file)),
105 path: Some(path),
106 callback: Some(Box::new(callback)),
107 current: Bytes::new(),
108 hasher: Hasher::new(),
109 },
110 })
111 }
112}
113
114impl<B> Body for CachingUpstreamSource<B>
115where
116 B: Body,
117 B::Data: Into<Bytes>,
118 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
119{
120 type Data = Bytes;
121 type Error = Box<dyn std::error::Error + Send + Sync>;
122
123 fn poll_frame(
124 mut self: Pin<&mut Self>,
125 cx: &mut Context<'_>,
126 ) -> Poll<Option<std::result::Result<Frame<Self::Data>, Self::Error>>> {
127 loop {
128 let this = self.as_mut().project();
129 match this.state.project() {
130 ProjectedCachingUpstreamSourceState::ReadingUpstream {
131 upstream,
132 mut writer,
133 path,
134 current,
135 hasher,
136 callback,
137 } => {
138 if current.is_empty() {
140 match ready!(upstream.poll_next(cx)) {
141 Some(Ok(frame)) => {
142 let frame = frame.map_data(Into::into);
143 match frame.into_data() {
144 Ok(data) if !data.is_empty() => {
145 hasher.update(&data);
147 *current = data;
148 }
149 Ok(_) => continue,
150 Err(frame) => return Poll::Ready(Some(Ok(frame))),
151 }
152 }
153 Some(Err(e)) => {
154 self.set(Self {
156 state: CachingUpstreamSourceState::Completed,
157 });
158 return Poll::Ready(Some(Err(e.into())));
159 }
160 None => {
161 let writer = writer.take();
162 let path = path.take();
163 let digest = hex::encode(hasher.finalize().as_bytes());
164 let callback = callback.take();
165
166 self.set(Self {
169 state: CachingUpstreamSourceState::FlushingFile {
170 writer,
171 path,
172 digest,
173 callback,
174 },
175 });
176 continue;
177 }
178 }
179 }
180
181 let mut data = current.clone();
183 return match ready!(writer.as_pin_mut().unwrap().poll_write(cx, &data)) {
184 Ok(n) => {
185 *current = data.split_off(n);
186 Poll::Ready(Some(Ok(Frame::data(data))))
187 }
188 Err(e) => {
189 self.set(Self {
190 state: CachingUpstreamSourceState::Completed,
191 });
192 Poll::Ready(Some(Err(e.into())))
193 }
194 };
195 }
196 ProjectedCachingUpstreamSourceState::FlushingFile {
197 mut writer,
198 path,
199 digest,
200 callback,
201 } => {
202 match ready!(writer.as_mut().as_pin_mut().unwrap().poll_flush(cx)) {
204 Ok(_) => {
205 drop(writer.take());
206 let path = path.take().unwrap();
207 let digest = std::mem::take(digest);
208 let callback = callback.take().unwrap();
209
210 let future = callback(digest, path);
212 self.set(Self {
213 state: CachingUpstreamSourceState::InvokingCallback { future },
214 });
215 continue;
216 }
217 Err(e) => {
218 self.set(Self {
219 state: CachingUpstreamSourceState::Completed,
220 });
221 return Poll::Ready(Some(Err(e.into())));
222 }
223 }
224 }
225 ProjectedCachingUpstreamSourceState::InvokingCallback { future } => {
226 return match ready!(future.poll(cx)) {
227 Ok(_) => {
228 self.set(Self {
229 state: CachingUpstreamSourceState::Completed,
230 });
231 Poll::Ready(None)
232 }
233 Err(e) => {
234 self.set(Self {
235 state: CachingUpstreamSourceState::Completed,
236 });
237 Poll::Ready(Some(Err(e.into_boxed_dyn_error())))
238 }
239 };
240 }
241 ProjectedCachingUpstreamSourceState::Completed => return Poll::Ready(None),
242 }
243 }
244 }
245}
246
247pin_project! {
248 struct FileSource {
250 #[pin]
252 reader: runtime::BufReader<runtime::File>,
253 len: u64,
255 buf: BytesMut,
257 finished: bool,
259 }
260}
261
262impl Body for FileSource {
263 type Data = Bytes;
264 type Error = io::Error;
265
266 fn poll_frame(
267 self: Pin<&mut Self>,
268 cx: &mut Context<'_>,
269 ) -> Poll<Option<io::Result<Frame<Self::Data>>>> {
270 let this = self.project();
271
272 if *this.finished {
273 return Poll::Ready(None);
274 }
275
276 if this.buf.capacity() == 0 {
277 this.buf.reserve(DEFAULT_CAPACITY);
278 }
279
280 cfg_if::cfg_if! {
281 if #[cfg(feature = "tokio")] {
282 match ready!(tokio_util::io::poll_read_buf(this.reader, cx, this.buf)) {
283 Ok(0) => {
284 *this.finished = true;
285 Poll::Ready(None)
286 }
287 Ok(_) => {
288 let chunk = this.buf.split();
289 Poll::Ready(Some(Ok(Frame::data(chunk.freeze()))))
290 }
291 Err(err) => {
292 *this.finished = true;
293 Poll::Ready(Some(Err(err)))
294 }
295 }
296 } else if #[cfg(feature = "smol")] {
297 use futures::AsyncRead;
298 use bytes::BufMut;
299
300 if !this.buf.has_remaining_mut() {
301 *this.finished = true;
302 return Poll::Ready(None);
303 }
304
305 let chunk = this.buf.chunk_mut();
306 let slice =
316 unsafe { std::slice::from_raw_parts_mut(chunk.as_mut_ptr(), chunk.len()) };
317 match ready!(this.reader.poll_read(cx, slice)) {
318 Ok(0) => {
319 *this.finished = true;
320 Poll::Ready(None)
321 }
322 Ok(n) => {
323 unsafe {
324 this.buf.advance_mut(n);
325 }
326 Poll::Ready(Some(Ok(Frame::data(this.buf.split().freeze()))))
327 }
328 Err(e) => {
329 *this.finished = true;
330 Poll::Ready(Some(Err(e)))
331 }
332 }
333 } else {
334 unimplemented!()
335 }
336 }
337 }
338}
339
pin_project! {
    #[project = ProjectedBodySource]
    /// Where the frames of a `CacheBody` come from.
    enum BodySource<B> {
        // Frames are forwarded directly from the upstream body.
        Upstream {
            // The upstream body.
            #[pin]
            source: BodyStream<B>
        },
        // Frames come from the upstream body and are copied to a temporary file
        // on the way through.
        CachingUpstream {
            // The caching wrapper around the upstream body.
            #[pin]
            source: CachingUpstreamSource<B>,
        },
        // Frames are read from a file.
        File {
            // The file reader.
            #[pin]
            source: FileSource
        },
    }
}
370
pin_project! {
    /// An HTTP body whose frames come from an upstream body, from an upstream
    /// body that is being copied to a temporary file, or from a file.
    pub struct CacheBody<B> {
        // The concrete source of the body's frames.
        #[pin]
        source: BodySource<B>
    }
}
381
382impl<B> CacheBody<B>
383where
384 B: Body,
385{
386 pub(crate) fn from_upstream(upstream: B) -> Self {
389 Self {
390 source: BodySource::Upstream {
391 source: BodyStream::new(upstream),
392 },
393 }
394 }
395
396 pub(crate) async fn from_caching_upstream<F>(
399 upstream: B,
400 temp_dir: &Path,
401 callback: F,
402 ) -> Result<Self>
403 where
404 F: FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send + 'static,
405 {
406 Ok(Self {
407 source: BodySource::CachingUpstream {
408 source: CachingUpstreamSource::new(upstream, temp_dir, callback).await?,
409 },
410 })
411 }
412
413 pub(crate) async fn from_file(file: runtime::File) -> Result<Self> {
415 let metadata = file.metadata().await?;
416
417 Ok(Self {
418 source: BodySource::File {
419 source: FileSource {
420 reader: runtime::BufReader::new(file),
421 len: metadata.len(),
422 buf: BytesMut::new(),
423 finished: false,
424 },
425 },
426 })
427 }
428}
429
430impl<B> Body for CacheBody<B>
431where
432 B: Body,
433 B::Data: Into<Bytes>,
434 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
435{
436 type Data = Bytes;
437 type Error = Box<dyn std::error::Error + Send + Sync>;
438
439 fn poll_frame(
440 self: Pin<&mut Self>,
441 cx: &mut Context<'_>,
442 ) -> Poll<Option<std::result::Result<http_body::Frame<Self::Data>, Self::Error>>> {
443 match self.project().source.project() {
444 ProjectedBodySource::Upstream { source } => source
445 .poll_frame(cx)
446 .map_ok(|f| f.map_data(Into::into))
447 .map_err(Into::into),
448 ProjectedBodySource::CachingUpstream { source } => source.poll_frame(cx),
449 ProjectedBodySource::File { source } => source.poll_frame(cx).map_err(Into::into),
450 }
451 }
452
453 fn is_end_stream(&self) -> bool {
454 match &self.source {
455 BodySource::Upstream { source } => source.is_end_stream(),
456 BodySource::CachingUpstream { source } => {
457 matches!(&source.state, CachingUpstreamSourceState::Completed)
458 }
459 BodySource::File { source } => source.finished,
460 }
461 }
462
463 fn size_hint(&self) -> http_body::SizeHint {
464 match &self.source {
465 BodySource::Upstream { source } => Body::size_hint(source),
466 BodySource::CachingUpstream { source } => match &source.state {
467 CachingUpstreamSourceState::ReadingUpstream { upstream, .. } => {
468 Body::size_hint(upstream)
469 }
470 _ => http_body::SizeHint::default(),
471 },
472 BodySource::File { source } => http_body::SizeHint::with_exact(source.len),
473 }
474 }
475}