1use std::io;
4use std::path::Path;
5use std::pin::Pin;
6use std::task::Context;
7use std::task::Poll;
8use std::task::ready;
9
10use anyhow::Context as _;
11use anyhow::Result;
12use blake3::Hasher;
13use bytes::Bytes;
14use bytes::BytesMut;
15use futures::Stream;
16use futures::future::BoxFuture;
17use http_body::Frame;
18use pin_project_lite::pin_project;
19use runtime::AsyncWrite;
20use tempfile::NamedTempFile;
21use tempfile::TempPath;
22
23use crate::HttpBody;
24use crate::runtime;
25
/// Number of bytes reserved for a `FileSource` read buffer (4 KiB).
const DEFAULT_CAPACITY: usize = 4 * 1024;
28
pin_project! {
    #[project = ProjectedCachingUpstreamSourceState]
    // States of `CachingUpstreamSource`. The source only moves forward:
    // `ReadingUpstream` -> `FlushingFile` -> `InvokingCallback` -> `Completed`,
    // and jumps straight to `Completed` whenever an error is yielded.
    enum CachingUpstreamSourceState<B> {
        // Streaming data from the upstream body while copying it into the
        // temporary file.
        ReadingUpstream {
            // The upstream body being read.
            #[pin]
            upstream: B,
            // Buffered writer for the temporary file. Held in an `Option` so it
            // can be moved into `FlushingFile` once upstream is drained.
            #[pin]
            writer: Option<runtime::BufWriter<runtime::File>>,
            // Path of the temporary file. Held in an `Option` so it can be moved
            // out; dropping the `TempPath` removes the file.
            path: Option<TempPath>,
            // Bytes already read from upstream that have not yet been written to
            // the file and yielded to the consumer.
            current: Bytes,
            // BLAKE3 hasher fed with every non-empty chunk read from upstream.
            hasher: Hasher,
            // Receives the hex digest and the file path once the file has been
            // flushed. Held in an `Option` so it can be moved out.
            callback: Option<Box<dyn FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send>>,
        },
        // Upstream is drained; flushing the buffered writer before the file is
        // handed to the callback.
        FlushingFile {
            // Buffered writer being flushed; taken (and dropped) once flushed.
            #[pin]
            writer: Option<runtime::BufWriter<runtime::File>>,
            // Path of the temporary file, passed to the callback.
            path: Option<TempPath>,
            // Hex-encoded BLAKE3 digest of the whole upstream body.
            digest: String,
            // Callback to invoke after the flush succeeds.
            callback: Option<Box<dyn FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send>>,
        },
        // Waiting for the future returned by the callback.
        InvokingCallback {
            #[pin]
            future: BoxFuture<'static, Result<()>>,
        },
        // Terminal state: the stream yields `None` from here on.
        Completed
    }
}
71
pin_project! {
    // A stream that yields the upstream body's data while writing the same bytes
    // to a temporary file, then hands the file (with the body's digest) to a
    // callback once upstream is exhausted.
    struct CachingUpstreamSource<B> {
        // Current position in the read -> flush -> callback state machine.
        #[pin]
        state: CachingUpstreamSourceState<B>,
    }
}
80
81impl<B> CachingUpstreamSource<B> {
82 async fn new<F>(upstream: B, temp_dir: &Path, callback: F) -> Result<Self>
86 where
87 F: FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send + 'static,
88 {
89 let path = NamedTempFile::new_in(temp_dir)
90 .context("failed to create temporary body file for cache storage")?
91 .into_temp_path();
92
93 let file = runtime::File::create(&*path).await.with_context(|| {
94 format!(
95 "failed to create temporary body file `{path}`",
96 path = path.display()
97 )
98 })?;
99
100 Ok(Self {
101 state: CachingUpstreamSourceState::ReadingUpstream {
102 upstream,
103 writer: Some(runtime::BufWriter::new(file)),
104 path: Some(path),
105 callback: Some(Box::new(callback)),
106 current: Bytes::new(),
107 hasher: Hasher::new(),
108 },
109 })
110 }
111}
112
113impl<B> Stream for CachingUpstreamSource<B>
114where
115 B: HttpBody,
116{
117 type Item = io::Result<Bytes>;
118
119 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
120 loop {
121 let this = self.as_mut().project();
122 match this.state.project() {
123 ProjectedCachingUpstreamSourceState::ReadingUpstream {
124 upstream,
125 mut writer,
126 path,
127 current,
128 hasher,
129 callback,
130 } => {
131 if current.is_empty() {
133 match ready!(upstream.poll_next_data(cx)) {
134 Some(Ok(data)) if data.is_empty() => continue,
135 Some(Ok(data)) => {
136 hasher.update(&data);
138 *current = data;
139 }
140 Some(Err(e)) => {
141 self.set(Self {
143 state: CachingUpstreamSourceState::Completed,
144 });
145 return Poll::Ready(Some(Err(e)));
146 }
147 None => {
148 let writer = writer.take();
149 let path = path.take();
150 let digest = hex::encode(hasher.finalize().as_bytes());
151 let callback = callback.take();
152
153 self.set(Self {
156 state: CachingUpstreamSourceState::FlushingFile {
157 writer,
158 path,
159 digest,
160 callback,
161 },
162 });
163 continue;
164 }
165 }
166 }
167
168 let mut data = current.clone();
170 return match ready!(writer.as_pin_mut().unwrap().poll_write(cx, &data)) {
171 Ok(n) => {
172 *current = data.split_off(n);
173 Poll::Ready(Some(Ok(data)))
174 }
175 Err(e) => {
176 self.set(Self {
177 state: CachingUpstreamSourceState::Completed,
178 });
179 Poll::Ready(Some(Err(e)))
180 }
181 };
182 }
183 ProjectedCachingUpstreamSourceState::FlushingFile {
184 mut writer,
185 path,
186 digest,
187 callback,
188 } => {
189 match ready!(writer.as_mut().as_pin_mut().unwrap().poll_flush(cx)) {
191 Ok(_) => {
192 drop(writer.take());
193 let path = path.take().unwrap();
194 let digest = std::mem::take(digest);
195 let callback = callback.take().unwrap();
196
197 let future = callback(digest, path);
199 self.set(Self {
200 state: CachingUpstreamSourceState::InvokingCallback { future },
201 });
202 continue;
203 }
204 Err(e) => {
205 self.set(Self {
206 state: CachingUpstreamSourceState::Completed,
207 });
208 return Poll::Ready(Some(Err(e)));
209 }
210 }
211 }
212 ProjectedCachingUpstreamSourceState::InvokingCallback { future } => {
213 return match ready!(future.poll(cx)) {
214 Ok(_) => {
215 self.set(Self {
216 state: CachingUpstreamSourceState::Completed,
217 });
218 Poll::Ready(None)
219 }
220 Err(e) => {
221 self.set(Self {
222 state: CachingUpstreamSourceState::Completed,
223 });
224 Poll::Ready(Some(Err(io::Error::other(e))))
225 }
226 };
227 }
228 ProjectedCachingUpstreamSourceState::Completed => return Poll::Ready(None),
229 }
230 }
231 }
232}
233
pin_project! {
    // A stream of chunks read from an already-opened file.
    struct FileSource {
        // Buffered reader over the file.
        #[pin]
        reader: runtime::BufReader<runtime::File>,
        // Length reported as the exact size hint; initialized from the file's
        // metadata when the source is created.
        len: u64,
        // Read buffer; each chunk that is read is split off it and frozen.
        buf: BytesMut,
        // Set once the reader reports EOF or an error; later polls return `None`.
        finished: bool,
    }
}
248
249impl Stream for FileSource {
250 type Item = io::Result<Bytes>;
251
252 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
253 let this = self.project();
254
255 if *this.finished {
256 return Poll::Ready(None);
257 }
258
259 if this.buf.capacity() == 0 {
260 this.buf.reserve(DEFAULT_CAPACITY);
261 }
262
263 cfg_if::cfg_if! {
264 if #[cfg(feature = "tokio")] {
265 match ready!(tokio_util::io::poll_read_buf(this.reader, cx, this.buf)) {
266 Ok(0) => {
267 *this.finished = true;
268 Poll::Ready(None)
269 }
270 Ok(_) => {
271 let chunk = this.buf.split();
272 Poll::Ready(Some(Ok(chunk.freeze())))
273 }
274 Err(err) => {
275 *this.finished = true;
276 Poll::Ready(Some(Err(err)))
277 }
278 }
279 } else if #[cfg(feature = "smol")] {
280 use futures::AsyncRead;
281 use bytes::BufMut;
282
283 if !this.buf.has_remaining_mut() {
284 *this.finished = true;
285 return Poll::Ready(None);
286 }
287
288 let chunk = this.buf.chunk_mut();
289 let slice =
299 unsafe { std::slice::from_raw_parts_mut(chunk.as_mut_ptr(), chunk.len()) };
300 match ready!(this.reader.poll_read(cx, slice)) {
301 Ok(0) => {
302 *this.finished = true;
303 Poll::Ready(None)
304 }
305 Ok(n) => {
306 unsafe {
307 this.buf.advance_mut(n);
308 }
309 Poll::Ready(Some(Ok(this.buf.split().freeze())))
310 }
311 Err(e) => {
312 *this.finished = true;
313 Poll::Ready(Some(Err(e)))
314 }
315 }
316 } else {
317 unimplemented!()
318 }
319 }
320 }
321}
322
pin_project! {
    #[project = ProjectedBodySource]
    // Where a `Body` reads its data from.
    enum BodySource<B> {
        // The upstream body, passed through unchanged.
        Upstream {
            #[pin]
            source: B
        },
        // The upstream body, additionally written to a temporary file that is
        // handed to a callback once fully read.
        CachingUpstream {
            #[pin]
            source: CachingUpstreamSource<B>,
        },
        // An already-opened file streamed in chunks.
        File {
            #[pin]
            source: FileSource
        },
    }
}
353
pin_project! {
    /// An HTTP body backed by one of several sources.
    ///
    /// The source is either the upstream body passed through unchanged, the
    /// upstream body copied into a temporary cache file while it streams, or an
    /// already-opened file.
    pub struct Body<B> {
        // The source this body reads its data from.
        #[pin]
        source: BodySource<B>
    }
}
362
363impl<B> Body<B>
364where
365 B: HttpBody,
366{
367 pub(crate) fn from_upstream(upstream: B) -> Self {
370 Self {
371 source: BodySource::Upstream { source: upstream },
372 }
373 }
374
375 pub(crate) async fn from_caching_upstream<F>(
378 upstream: B,
379 temp_dir: &Path,
380 callback: F,
381 ) -> Result<Self>
382 where
383 F: FnOnce(String, TempPath) -> BoxFuture<'static, Result<()>> + Send + 'static,
384 {
385 Ok(Self {
386 source: BodySource::CachingUpstream {
387 source: CachingUpstreamSource::new(upstream, temp_dir, callback).await?,
388 },
389 })
390 }
391
392 pub(crate) async fn from_file(file: runtime::File) -> Result<Self> {
394 let metadata = file.metadata().await?;
395
396 Ok(Self {
397 source: BodySource::File {
398 source: FileSource {
399 reader: runtime::BufReader::new(file),
400 len: metadata.len(),
401 buf: BytesMut::new(),
402 finished: false,
403 },
404 },
405 })
406 }
407}
408
409impl<B> http_body::Body for Body<B>
410where
411 B: HttpBody,
412{
413 type Data = Bytes;
414 type Error = io::Error;
415
416 fn poll_frame(
417 self: Pin<&mut Self>,
418 cx: &mut Context<'_>,
419 ) -> Poll<Option<Result<Frame<Self::Data>, io::Error>>> {
420 match self.project().source.project() {
421 ProjectedBodySource::Upstream { source } => source.poll_frame(cx),
422 ProjectedBodySource::CachingUpstream { source } => {
423 source.poll_next(cx).map_ok(Frame::data)
424 }
425 ProjectedBodySource::File { source } => source.poll_next(cx).map_ok(Frame::data),
426 }
427 }
428
429 fn is_end_stream(&self) -> bool {
430 match &self.source {
431 BodySource::Upstream { source } => source.is_end_stream(),
432 BodySource::CachingUpstream { source } => {
433 matches!(&source.state, CachingUpstreamSourceState::Completed)
434 }
435 BodySource::File { source } => source.finished,
436 }
437 }
438
439 fn size_hint(&self) -> http_body::SizeHint {
440 match &self.source {
441 BodySource::Upstream { source } => source.size_hint(),
442 BodySource::CachingUpstream { source } => match &source.state {
443 CachingUpstreamSourceState::ReadingUpstream { upstream, .. } => {
444 upstream.size_hint()
445 }
446 _ => http_body::SizeHint::default(),
447 },
448 BodySource::File { source } => http_body::SizeHint::with_exact(source.len),
449 }
450 }
451}
452
453impl<B> HttpBody for Body<B> where B: HttpBody + Send {}
454
455impl<B> Stream for Body<B>
461where
462 B: HttpBody,
463{
464 type Item = io::Result<Bytes>;
465
466 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
467 match self.project().source.project() {
468 ProjectedBodySource::Upstream { source } => source.poll_next_data(cx),
469 ProjectedBodySource::CachingUpstream { source } => source.poll_next(cx),
470 ProjectedBodySource::File { source } => source.poll_next(cx),
471 }
472 }
473}