google_cloud_storage/storage/
upload_source.rs1use std::collections::VecDeque;
18
19pub type SizeHint = http_body::SizeHint;
21
/// A wrapper around a data source.
///
/// `Payload<T>` implements [`StreamingSource`] when `T` does, and [`Seek`]
/// when `T` does, by forwarding to the wrapped value. `From` conversions are
/// provided for `bytes::Bytes`, `&'static str`, `Vec<bytes::Bytes>`,
/// `tokio::fs::File`, and any [`StreamingSource`].
pub struct Payload<T> {
    /// The wrapped data source; every trait method forwards to it.
    payload: T,
}
46
47impl<T> Payload<T>
48where
49 T: StreamingSource,
50{
51 pub fn from_stream(payload: T) -> Self {
52 Self { payload }
53 }
54}
55
56impl<T> StreamingSource for Payload<T>
57where
58 T: StreamingSource + Send + Sync,
59{
60 type Error = T::Error;
61
62 async fn next(&mut self) -> Option<Result<bytes::Bytes, Self::Error>> {
63 self.payload.next().await
64 }
65
66 async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
67 self.payload.size_hint().await
68 }
69}
70
71impl<T> Seek for Payload<T>
72where
73 T: Seek,
74{
75 type Error = T::Error;
76
77 fn seek(&mut self, offset: u64) -> impl Future<Output = Result<(), Self::Error>> + Send {
78 self.payload.seek(offset)
79 }
80}
81
82impl From<bytes::Bytes> for Payload<BytesSource> {
83 fn from(value: bytes::Bytes) -> Self {
84 let payload = BytesSource::new(value);
85 Self { payload }
86 }
87}
88
89impl From<&'static str> for Payload<BytesSource> {
90 fn from(value: &'static str) -> Self {
91 let b = bytes::Bytes::from_static(value.as_bytes());
92 Payload::from(b)
93 }
94}
95
96impl From<Vec<bytes::Bytes>> for Payload<IterSource> {
97 fn from(value: Vec<bytes::Bytes>) -> Self {
98 let payload = IterSource::new(value);
99 Self { payload }
100 }
101}
102
103impl<S> From<S> for Payload<S>
104where
105 S: StreamingSource,
106{
107 fn from(value: S) -> Self {
108 Self { payload: value }
109 }
110}
111
112pub trait StreamingSource {
114 type Error: std::error::Error + Send + Sync + 'static;
116
117 fn next(&mut self) -> impl Future<Output = Option<Result<bytes::Bytes, Self::Error>>> + Send;
119
120 fn size_hint(&self) -> impl Future<Output = Result<SizeHint, Self::Error>> + Send {
128 std::future::ready(Ok(SizeHint::new()))
129 }
130}
131
/// A source that can be repositioned to an absolute byte offset.
pub trait Seek {
    /// The error produced when repositioning fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Repositions the source so that subsequent reads start `offset` bytes
    /// from the beginning of its data.
    fn seek(&mut self, offset: u64) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
152
/// Maximum number of bytes `FileSource` requests from the file per chunk (256 KiB).
const READ_SIZE: usize = 256 * 1024;
154
155impl From<tokio::fs::File> for Payload<FileSource> {
156 fn from(value: tokio::fs::File) -> Self {
157 Self {
158 payload: FileSource::new(value),
159 }
160 }
161}
162
163pub struct FileSource {
178 inner: tokio::fs::File,
179}
180
181impl FileSource {
182 fn new(inner: tokio::fs::File) -> Self {
183 Self { inner }
184 }
185}
186
187impl StreamingSource for FileSource {
188 type Error = std::io::Error;
189
190 async fn next(&mut self) -> Option<Result<bytes::Bytes, Self::Error>> {
191 let mut buffer = vec![0_u8; READ_SIZE];
192 match tokio::io::AsyncReadExt::read(&mut self.inner, &mut buffer).await {
193 Err(e) => Some(Err(e)),
194 Ok(0) => None,
195 Ok(n) => {
196 buffer.resize(n, 0_u8);
197 Some(Ok(bytes::Bytes::from_owner(buffer)))
198 }
199 }
200 }
201 async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
202 let m = self.inner.metadata().await?;
203 Ok(SizeHint::with_exact(m.len()))
204 }
205}
206
207impl Seek for FileSource {
208 type Error = std::io::Error;
209
210 async fn seek(&mut self, offset: u64) -> Result<(), Self::Error> {
211 use tokio::io::AsyncSeekExt;
212 let _ = self.inner.seek(std::io::SeekFrom::Start(offset)).await?;
213 Ok(())
214 }
215}
216
217pub struct BytesSource {
232 contents: bytes::Bytes,
233 current: Option<bytes::Bytes>,
234}
235
236impl BytesSource {
237 pub(crate) fn new(contents: bytes::Bytes) -> Self {
238 let current = Some(contents.clone());
239 Self { contents, current }
240 }
241}
242
243impl StreamingSource for BytesSource {
244 type Error = crate::Error;
245
246 async fn next(&mut self) -> Option<Result<bytes::Bytes, Self::Error>> {
247 self.current.take().map(Result::Ok)
248 }
249
250 async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
251 let s = self.contents.len() as u64;
252 Ok(SizeHint::with_exact(s))
253 }
254}
255
256impl Seek for BytesSource {
257 type Error = crate::Error;
258
259 async fn seek(&mut self, offset: u64) -> Result<(), Self::Error> {
260 let pos = std::cmp::min(offset as usize, self.contents.len());
261 self.current = Some(self.contents.slice(pos..));
262 Ok(())
263 }
264}
265
266pub(crate) struct IterSource {
268 contents: Vec<bytes::Bytes>,
269 current: VecDeque<bytes::Bytes>,
270}
271
272impl IterSource {
273 pub(crate) fn new<I>(iterator: I) -> Self
274 where
275 I: IntoIterator<Item = bytes::Bytes>,
276 {
277 let contents: Vec<bytes::Bytes> = iterator.into_iter().collect();
278 let current: VecDeque<bytes::Bytes> = contents.iter().cloned().collect();
279 Self { contents, current }
280 }
281}
282
283impl StreamingSource for IterSource {
284 type Error = std::io::Error;
285
286 async fn next(&mut self) -> Option<std::result::Result<bytes::Bytes, Self::Error>> {
287 self.current.pop_front().map(Ok)
288 }
289
290 async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
291 let s = self.contents.iter().fold(0_u64, |a, i| a + i.len() as u64);
292 Ok(SizeHint::with_exact(s))
293 }
294}
295
296impl Seek for IterSource {
297 type Error = std::io::Error;
298 async fn seek(&mut self, offset: u64) -> std::result::Result<(), Self::Error> {
299 let mut current = VecDeque::new();
300 let mut offset = offset as usize;
301 for b in self.contents.iter() {
302 offset = match (offset, b.len()) {
303 (0, _) => {
304 current.push_back(b.clone());
305 0
306 }
307 (o, n) if o >= n => o - n,
308 (o, n) => {
309 current.push_back(b.clone().split_off(n - o));
310 0
311 }
312 }
313 }
314 self.current = current;
315 Ok(())
316 }
317}
318
#[cfg(test)]
pub mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    type Result = anyhow::Result<()>;

    /// Sample data shared by several tests.
    const CONTENTS: &[u8] = b"how vexingly quick daft zebras jump";

    /// Wraps a [`BytesSource`] but hides its exact size: `size_hint()` reports
    /// only the inner lower bound, with no upper bound.
    pub(crate) struct UnknownSize {
        inner: BytesSource,
    }

    impl UnknownSize {
        pub fn new(inner: BytesSource) -> Self {
            UnknownSize { inner }
        }
    }

    impl Seek for UnknownSize {
        type Error = <BytesSource as Seek>::Error;

        async fn seek(&mut self, offset: u64) -> std::result::Result<(), Self::Error> {
            Seek::seek(&mut self.inner, offset).await
        }
    }

    impl StreamingSource for UnknownSize {
        type Error = <BytesSource as StreamingSource>::Error;

        async fn next(&mut self) -> Option<std::result::Result<bytes::Bytes, Self::Error>> {
            StreamingSource::next(&mut self.inner).await
        }

        async fn size_hint(&self) -> std::result::Result<SizeHint, Self::Error> {
            let lower = self.inner.size_hint().await?.lower();
            let mut hint = SizeHint::default();
            hint.set_lower(lower);
            Ok(hint)
        }
    }

    mockall::mock! {
        pub(crate) SimpleSource {}

        impl StreamingSource for SimpleSource {
            type Error = std::io::Error;
            async fn next(&mut self) -> Option<std::result::Result<bytes::Bytes, std::io::Error>>;
            async fn size_hint(&self) -> std::result::Result<SizeHint, std::io::Error>;
        }
    }

    mockall::mock! {
        pub(crate) SeekSource {}

        impl StreamingSource for SeekSource {
            type Error = std::io::Error;
            async fn next(&mut self) -> Option<std::result::Result<bytes::Bytes, std::io::Error>>;
            async fn size_hint(&self) -> std::result::Result<SizeHint, std::io::Error>;
        }
        impl Seek for SeekSource {
            type Error = std::io::Error;
            async fn seek(&mut self, offset: u64) -> std::result::Result<(), std::io::Error>;
        }
    }

    /// Drains `source`, concatenating every chunk.
    async fn collect<S>(mut source: S) -> anyhow::Result<Vec<u8>>
    where
        S: StreamingSource,
    {
        collect_mut(&mut source).await
    }

    /// Drains `source` in place so it can be reused (e.g. after a `seek`).
    async fn collect_mut<S>(source: &mut S) -> anyhow::Result<Vec<u8>>
    where
        S: StreamingSource,
    {
        let mut out = Vec::new();
        loop {
            match source.next().await {
                None => return Ok(out),
                Some(chunk) => out.extend_from_slice(&chunk?),
            }
        }
    }

    /// Creates a temporary file holding exactly `data`.
    fn temp_file_with(data: &[u8]) -> anyhow::Result<NamedTempFile> {
        let mut file = NamedTempFile::new()?;
        assert_eq!(file.write(data)?, data.len());
        file.flush()?;
        Ok(file)
    }

    #[tokio::test]
    async fn empty_bytes() -> Result {
        let payload = Payload::from(bytes::Bytes::new());
        assert_eq!(payload.size_hint().await?.exact(), Some(0));
        let got = collect(payload).await?;
        assert!(got.is_empty(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn simple_bytes() -> Result {
        let payload = Payload::from(bytes::Bytes::from_static(CONTENTS));
        let want = CONTENTS.len() as u64;
        assert_eq!(payload.size_hint().await?.exact(), Some(want));
        let got = collect(payload).await?;
        assert_eq!(got[..], CONTENTS[..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn simple_str() -> Result {
        const LAZY: &str = "the quick brown fox jumps over the lazy dog";
        let payload = Payload::from(LAZY);
        let hint = payload.size_hint().await?;
        assert_eq!(hint.exact(), Some(LAZY.len() as u64));
        let got = collect(payload).await?;
        assert_eq!(&got, LAZY.as_bytes(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn seek_bytes() -> Result {
        let mut payload = Payload::from(bytes::Bytes::from_static(CONTENTS));
        payload.seek(8).await?;
        let got = collect(payload).await?;
        assert_eq!(got[..], CONTENTS[8..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn empty_stream() -> Result {
        let payload = Payload::from(IterSource::new(Vec::new()));
        assert_eq!(payload.size_hint().await?.exact(), Some(0));
        let got = collect(payload).await?;
        assert!(got.is_empty(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn simple_stream() -> Result {
        let words = ["how ", "vexingly ", "quick ", "daft ", "zebras ", "jump"];
        let source = IterSource::new(words.map(|w| bytes::Bytes::from_static(w.as_bytes())));
        let got = collect(Payload::from_stream(source)).await?;
        assert_eq!(got[..], CONTENTS[..]);
        Ok(())
    }

    #[tokio::test]
    async fn empty_file() -> Result {
        let file = NamedTempFile::new()?;
        let payload = Payload::from(tokio::fs::File::from(file.reopen()?));
        assert_eq!(payload.size_hint().await?.exact(), Some(0));
        let got = collect(payload).await?;
        assert!(got.is_empty(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn small_file() -> Result {
        let file = temp_file_with(CONTENTS)?;
        let payload = Payload::from(tokio::fs::File::from(file.reopen()?));
        let want = CONTENTS.len() as u64;
        assert_eq!(payload.size_hint().await?.exact(), Some(want));
        let got = collect(payload).await?;
        assert_eq!(got[..], CONTENTS[..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn small_file_seek() -> Result {
        let file = temp_file_with(CONTENTS)?;
        let mut payload = Payload::from(tokio::fs::File::from(file.reopen()?));
        payload.seek(8).await?;
        let got = collect(payload).await?;
        assert_eq!(got[..], CONTENTS[8..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn larger_file() -> Result {
        assert_eq!(READ_SIZE % 2, 0);
        let mut file = NamedTempFile::new()?;
        // Four READ_SIZE blocks filled with 0, 1, 2 and 3 respectively.
        for fill in 0_u8..4 {
            let block = vec![fill; READ_SIZE];
            assert_eq!(file.write(&block)?, READ_SIZE);
        }
        file.flush()?;

        let mut payload = Payload::from(tokio::fs::File::from(file.reopen()?));
        let half = READ_SIZE / 2;
        payload.seek((READ_SIZE + half) as u64).await?;
        let got = collect(payload).await?;

        // Second half of block 1, then blocks 2 and 3 in full.
        let mut want = vec![1_u8; half];
        want.resize(half + READ_SIZE, 2_u8);
        want.resize(half + 2 * READ_SIZE, 3_u8);
        assert_eq!(got[..], want[..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn iter_source_full() -> Result {
        const N: usize = 32;
        let mut all = vec![1_u8; N];
        all.resize(2 * N, 2_u8);
        all.resize(3 * N, 3_u8);
        let b = bytes::Bytes::from_owner(all);

        let mut stream = IterSource::new((0..3).map(|i| b.slice(i * N..(i + 1) * N)));
        assert_eq!(stream.size_hint().await?.exact(), Some(3 * N as u64));

        // Seeking repeatedly must rebuild the remaining chunks from scratch.
        for offset in [0, N / 2, 0, N, 0, 2 * N + N / 2] {
            stream.seek(offset as u64).await?;
            let got = collect_mut(&mut stream).await?;
            assert_eq!(got[..], b[offset..]);
        }
        Ok(())
    }
}