google_cloud_storage/storage/
upload_source.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Defines upload data sources.
16
17use std::collections::VecDeque;
18
19/// The *total* number of bytes expected in a [StreamingSource].
20pub type SizeHint = http_body::SizeHint;
21
/// The payload for object uploads via the [Storage][crate::client::Storage]
/// client.
///
/// The storage client functions that upload new objects accept any type that
/// can be converted to this type. That includes `bytes::Bytes`,
/// `&'static str`, `Vec<bytes::Bytes>`, `tokio::fs::File`, and any type
/// implementing [StreamingSource].
///
/// # Example
/// ```
/// # tokio_test::block_on(async {
/// # use google_cloud_storage::upload_source::Payload;
/// use google_cloud_storage::upload_source::StreamingSource;
/// let buffer : &[u8] = b"the quick brown fox jumps over the lazy dog";
/// let mut size = 0_usize;
/// let mut payload = Payload::from(bytes::Bytes::from_static(buffer));
/// while let Some(bytes) = payload.next().await.transpose()? {
///     size += bytes.len();
/// }
/// assert_eq!(size, buffer.len());
/// # anyhow::Result::<()>::Ok(()) });
/// ```
#[derive(Debug)]
pub struct Payload<T> {
    /// The wrapped data source; every trait call on `Payload` forwards to it.
    payload: T,
}
46
47impl<T> Payload<T>
48where
49    T: StreamingSource,
50{
51    pub fn from_stream(payload: T) -> Self {
52        Self { payload }
53    }
54}
55
56impl<T> StreamingSource for Payload<T>
57where
58    T: StreamingSource + Send + Sync,
59{
60    type Error = T::Error;
61
62    async fn next(&mut self) -> Option<Result<bytes::Bytes, Self::Error>> {
63        self.payload.next().await
64    }
65
66    async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
67        self.payload.size_hint().await
68    }
69}
70
71impl<T> Seek for Payload<T>
72where
73    T: Seek,
74{
75    type Error = T::Error;
76
77    fn seek(&mut self, offset: u64) -> impl Future<Output = Result<(), Self::Error>> + Send {
78        self.payload.seek(offset)
79    }
80}
81
82impl From<bytes::Bytes> for Payload<BytesSource> {
83    fn from(value: bytes::Bytes) -> Self {
84        let payload = BytesSource::new(value);
85        Self { payload }
86    }
87}
88
89impl From<&'static str> for Payload<BytesSource> {
90    fn from(value: &'static str) -> Self {
91        let b = bytes::Bytes::from_static(value.as_bytes());
92        Payload::from(b)
93    }
94}
95
96impl From<Vec<bytes::Bytes>> for Payload<IterSource> {
97    fn from(value: Vec<bytes::Bytes>) -> Self {
98        let payload = IterSource::new(value);
99        Self { payload }
100    }
101}
102
103impl<S> From<S> for Payload<S>
104where
105    S: StreamingSource,
106{
107    fn from(value: S) -> Self {
108        Self { payload: value }
109    }
110}
111
112/// Provides bytes for an upload from single-pass sources.
113pub trait StreamingSource {
114    /// The error type.
115    type Error: std::error::Error + Send + Sync + 'static;
116
117    /// Gets the next set of data to upload.
118    fn next(&mut self) -> impl Future<Output = Option<Result<bytes::Bytes, Self::Error>>> + Send;
119
120    /// An estimate of the upload size.
121    ///
122    /// Returns the expected size as a [min, max) range. Where `None` represents
123    /// an unknown limit for the upload.
124    ///
125    /// If the maximum size is known and sufficiently small, the client library
126    /// may be able to use a more efficient protocol for the upload.
127    fn size_hint(&self) -> impl Future<Output = Result<SizeHint, Self::Error>> + Send {
128        std::future::ready(Ok(SizeHint::new()))
129    }
130}
131
/// Provides bytes for an upload from sources that support seek.
///
/// Implementations of this trait supply data for Google Cloud Storage uploads
/// that can be restarted from an arbitrary position. The data may arrive
/// asynchronously, for example from Google Cloud Storage downloads, other
/// remote storage systems, or the result of repeatable computations.
pub trait Seek {
    /// The error type.
    type Error: 'static + Send + Sync + std::error::Error;

    /// Resets the stream to start from `offset`.
    ///
    /// The client library restarts uploads automatically when the connection
    /// is reset or after a partial failure, and resuming may require moving
    /// the stream to an arbitrary point.
    ///
    /// The client library assumes that `seek(N)` followed by `next()` always
    /// returns the same data.
    fn seek(
        &mut self,
        offset: u64,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
}
152
/// The maximum number of bytes requested from a file in a single read (256 KiB).
///
/// `FileSource` allocates a buffer of this size for each chunk it returns.
const READ_SIZE: usize = 256 << 10;
154
155impl From<tokio::fs::File> for Payload<FileSource> {
156    fn from(value: tokio::fs::File) -> Self {
157        Self {
158            payload: FileSource::new(value),
159        }
160    }
161}
162
163/// Implements [StreamingSource] for a [tokio::fs::File].
164///
165/// # Example
166/// ```
167/// # use google_cloud_storage::client::Storage;
168/// # async fn sample(client: &Storage) -> anyhow::Result<()> {
169/// let payload = tokio::fs::File::open("my-data").await?;
170/// let response = client
171///     .upload_object("projects/_/buckets/my-bucket", "my-object", payload)
172///     .send_unbuffered()
173///     .await?;
174/// println!("response details={response:?}");
175/// # Ok(()) }
176/// ```
177pub struct FileSource {
178    inner: tokio::fs::File,
179}
180
181impl FileSource {
182    fn new(inner: tokio::fs::File) -> Self {
183        Self { inner }
184    }
185}
186
187impl StreamingSource for FileSource {
188    type Error = std::io::Error;
189
190    async fn next(&mut self) -> Option<Result<bytes::Bytes, Self::Error>> {
191        let mut buffer = vec![0_u8; READ_SIZE];
192        match tokio::io::AsyncReadExt::read(&mut self.inner, &mut buffer).await {
193            Err(e) => Some(Err(e)),
194            Ok(0) => None,
195            Ok(n) => {
196                buffer.resize(n, 0_u8);
197                Some(Ok(bytes::Bytes::from_owner(buffer)))
198            }
199        }
200    }
201    async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
202        let m = self.inner.metadata().await?;
203        Ok(SizeHint::with_exact(m.len()))
204    }
205}
206
207impl Seek for FileSource {
208    type Error = std::io::Error;
209
210    async fn seek(&mut self, offset: u64) -> Result<(), Self::Error> {
211        use tokio::io::AsyncSeekExt;
212        let _ = self.inner.seek(std::io::SeekFrom::Start(offset)).await?;
213        Ok(())
214    }
215}
216
217/// Implements [StreamingSource] for [bytes::Bytes].
218///
219/// # Example
220/// ```
221/// # use google_cloud_storage::client::Storage;
222/// # async fn sample(client: &Storage) -> anyhow::Result<()> {
223/// let payload = bytes::Bytes::from_static(b"Hello World!");
224/// let response = client
225///     .upload_object("projects/_/buckets/my-bucket", "my-object", payload)
226///     .send_unbuffered()
227///     .await?;
228/// println!("response details={response:?}");
229/// # Ok(()) }
230/// ```
231pub struct BytesSource {
232    contents: bytes::Bytes,
233    current: Option<bytes::Bytes>,
234}
235
236impl BytesSource {
237    pub(crate) fn new(contents: bytes::Bytes) -> Self {
238        let current = Some(contents.clone());
239        Self { contents, current }
240    }
241}
242
243impl StreamingSource for BytesSource {
244    type Error = crate::Error;
245
246    async fn next(&mut self) -> Option<Result<bytes::Bytes, Self::Error>> {
247        self.current.take().map(Result::Ok)
248    }
249
250    async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
251        let s = self.contents.len() as u64;
252        Ok(SizeHint::with_exact(s))
253    }
254}
255
256impl Seek for BytesSource {
257    type Error = crate::Error;
258
259    async fn seek(&mut self, offset: u64) -> Result<(), Self::Error> {
260        let pos = std::cmp::min(offset as usize, self.contents.len());
261        self.current = Some(self.contents.slice(pos..));
262        Ok(())
263    }
264}
265
266/// Implements [StreamingSource] for a sequence of [bytes::Bytes].
267pub(crate) struct IterSource {
268    contents: Vec<bytes::Bytes>,
269    current: VecDeque<bytes::Bytes>,
270}
271
272impl IterSource {
273    pub(crate) fn new<I>(iterator: I) -> Self
274    where
275        I: IntoIterator<Item = bytes::Bytes>,
276    {
277        let contents: Vec<bytes::Bytes> = iterator.into_iter().collect();
278        let current: VecDeque<bytes::Bytes> = contents.iter().cloned().collect();
279        Self { contents, current }
280    }
281}
282
283impl StreamingSource for IterSource {
284    type Error = std::io::Error;
285
286    async fn next(&mut self) -> Option<std::result::Result<bytes::Bytes, Self::Error>> {
287        self.current.pop_front().map(Ok)
288    }
289
290    async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
291        let s = self.contents.iter().fold(0_u64, |a, i| a + i.len() as u64);
292        Ok(SizeHint::with_exact(s))
293    }
294}
295
296impl Seek for IterSource {
297    type Error = std::io::Error;
298    async fn seek(&mut self, offset: u64) -> std::result::Result<(), Self::Error> {
299        let mut current = VecDeque::new();
300        let mut offset = offset as usize;
301        for b in self.contents.iter() {
302            offset = match (offset, b.len()) {
303                (0, _) => {
304                    current.push_back(b.clone());
305                    0
306                }
307                (o, n) if o >= n => o - n,
308                (o, n) => {
309                    current.push_back(b.clone().split_off(n - o));
310                    0
311                }
312            }
313        }
314        self.current = current;
315        Ok(())
316    }
317}
318
#[cfg(test)]
pub mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// The return type of every test in this module.
    type TestResult = anyhow::Result<()>;

    /// A short payload shared by several tests.
    const CONTENTS: &[u8] = b"how vexingly quick daft zebras jump";

    /// Wraps a [BytesSource] but reports a size hint without an upper bound.
    ///
    /// Only the lower bound of the wrapped hint is forwarded, so the total
    /// size looks unknown. Reads and seeks go straight to the wrapped source.
    pub(crate) struct UnknownSize {
        inner: BytesSource,
    }

    impl UnknownSize {
        pub fn new(inner: BytesSource) -> Self {
            UnknownSize { inner }
        }
    }

    impl Seek for UnknownSize {
        type Error = <BytesSource as Seek>::Error;

        async fn seek(&mut self, offset: u64) -> Result<(), Self::Error> {
            Seek::seek(&mut self.inner, offset).await
        }
    }

    impl StreamingSource for UnknownSize {
        type Error = <BytesSource as StreamingSource>::Error;

        async fn next(&mut self) -> Option<Result<bytes::Bytes, Self::Error>> {
            StreamingSource::next(&mut self.inner).await
        }

        async fn size_hint(&self) -> Result<SizeHint, Self::Error> {
            let lower = self.inner.size_hint().await?.lower();
            let mut hint = SizeHint::default();
            hint.set_lower(lower);
            Ok(hint)
        }
    }

    mockall::mock! {
        pub(crate) SimpleSource {}

        impl StreamingSource for SimpleSource {
            type Error = std::io::Error;
            async fn next(&mut self) -> Option<std::result::Result<bytes::Bytes, std::io::Error>>;
            async fn size_hint(&self) -> std::result::Result<SizeHint, std::io::Error>;
        }
    }

    mockall::mock! {
        pub(crate) SeekSource {}

        impl StreamingSource for SeekSource {
            type Error = std::io::Error;
            async fn next(&mut self) -> Option<std::result::Result<bytes::Bytes, std::io::Error>>;
            async fn size_hint(&self) -> std::result::Result<SizeHint, std::io::Error>;
        }
        impl Seek for SeekSource {
            type Error = std::io::Error;
            async fn seek(&mut self, offset: u64) -> std::result::Result<(), std::io::Error>;
        }
    }

    /// Drains `source`, returning all of its bytes concatenated.
    async fn collect<S>(mut source: S) -> anyhow::Result<Vec<u8>>
    where
        S: StreamingSource,
    {
        collect_mut(&mut source).await
    }

    /// Drains `source` through a mutable borrow, so the caller can keep using
    /// it (for example, seek and drain again) afterwards.
    async fn collect_mut<S>(source: &mut S) -> anyhow::Result<Vec<u8>>
    where
        S: StreamingSource,
    {
        let mut out = Vec::new();
        loop {
            match source.next().await {
                None => break,
                Some(chunk) => out.extend_from_slice(&chunk?),
            }
        }
        Ok(out)
    }

    /// Creates a temporary file holding exactly `data`.
    fn temp_file_with(data: &[u8]) -> anyhow::Result<NamedTempFile> {
        let mut temp = NamedTempFile::new()?;
        assert_eq!(temp.write(data)?, data.len());
        temp.flush()?;
        Ok(temp)
    }

    /// Opens an independent handle to `temp` and wraps it as a payload.
    fn file_payload(temp: &NamedTempFile) -> anyhow::Result<Payload<FileSource>> {
        let file = tokio::fs::File::from(temp.reopen()?);
        Ok(Payload::from(file))
    }

    #[tokio::test]
    async fn empty_bytes() -> TestResult {
        let payload = Payload::from(bytes::Bytes::new());
        assert_eq!(payload.size_hint().await?.exact(), Some(0));
        let got = collect(payload).await?;
        assert!(got.is_empty(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn simple_bytes() -> TestResult {
        let payload = Payload::from(bytes::Bytes::from_static(CONTENTS));
        let want = Some(CONTENTS.len() as u64);
        assert_eq!(payload.size_hint().await?.exact(), want);
        let got = collect(payload).await?;
        assert_eq!(got[..], CONTENTS[..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn simple_str() -> TestResult {
        const LAZY: &str = "the quick brown fox jumps over the lazy dog";
        let payload = Payload::from(LAZY);
        assert_eq!(payload.size_hint().await?.exact(), Some(LAZY.len() as u64));
        let got = collect(payload).await?;
        assert_eq!(&got, LAZY.as_bytes(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn seek_bytes() -> TestResult {
        const SKIP: usize = 8;
        let mut payload = Payload::from(bytes::Bytes::from_static(CONTENTS));
        payload.seek(SKIP as u64).await?;
        let got = collect(payload).await?;
        assert_eq!(got[..], CONTENTS[SKIP..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn empty_stream() -> TestResult {
        let payload = Payload::from(IterSource::new(Vec::new()));
        assert_eq!(payload.size_hint().await?.exact(), Some(0));
        let got = collect(payload).await?;
        assert!(got.is_empty(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn simple_stream() -> TestResult {
        let words = ["how ", "vexingly ", "quick ", "daft ", "zebras ", "jump"];
        let source = IterSource::new(words.map(|w| bytes::Bytes::from_static(w.as_bytes())));
        let got = collect(Payload::from_stream(source)).await?;
        assert_eq!(got[..], CONTENTS[..]);
        Ok(())
    }

    #[tokio::test]
    async fn empty_file() -> TestResult {
        let temp = NamedTempFile::new()?;
        let payload = file_payload(&temp)?;
        assert_eq!(payload.size_hint().await?.exact(), Some(0));
        let got = collect(payload).await?;
        assert!(got.is_empty(), "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn small_file() -> TestResult {
        let temp = temp_file_with(CONTENTS)?;
        let payload = file_payload(&temp)?;
        let want = CONTENTS.len() as u64;
        assert_eq!(payload.size_hint().await?.exact(), Some(want));
        let got = collect(payload).await?;
        assert_eq!(got[..], CONTENTS[..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn small_file_seek() -> TestResult {
        let temp = temp_file_with(CONTENTS)?;
        let mut payload = file_payload(&temp)?;
        payload.seek(8).await?;
        let got = collect(payload).await?;
        assert_eq!(got[..], CONTENTS[8..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn larger_file() -> TestResult {
        assert_eq!(READ_SIZE % 2, 0);
        // Four READ_SIZE blocks filled with 0, 1, 2 and 3 respectively.
        let mut temp = NamedTempFile::new()?;
        for fill in 0_u8..4 {
            let block = vec![fill; READ_SIZE];
            assert_eq!(temp.write(&block)?, READ_SIZE);
        }
        temp.flush()?;
        let mut payload = file_payload(&temp)?;
        // Start halfway through the second block.
        payload.seek((READ_SIZE + READ_SIZE / 2) as u64).await?;
        let got = collect(payload).await?;
        let want: Vec<u8> = std::iter::repeat_n(1_u8, READ_SIZE / 2)
            .chain(std::iter::repeat_n(2_u8, READ_SIZE))
            .chain(std::iter::repeat_n(3_u8, READ_SIZE))
            .collect();
        assert_eq!(got[..], want[..], "{got:?}");
        Ok(())
    }

    #[tokio::test]
    async fn iter_source_full() -> TestResult {
        const N: usize = 32;
        let data: Vec<u8> = [1_u8, 2, 3]
            .into_iter()
            .flat_map(|fill| std::iter::repeat_n(fill, N))
            .collect();
        let b = bytes::Bytes::from_owner(data);

        let mut stream = IterSource::new(vec![b.slice(..N), b.slice(N..2 * N), b.slice(2 * N..)]);
        assert_eq!(stream.size_hint().await?.exact(), Some(3 * N as u64));

        // Deliberately reuse the *same* stream for every offset (which is why
        // this is not a test_case()): seek() must work repeatedly.
        for offset in [0, N / 2, 0, N, 0, 2 * N + N / 2] {
            stream.seek(offset as u64).await?;
            let got = collect_mut(&mut stream).await?;
            assert_eq!(got[..], b[offset..]);
        }

        Ok(())
    }
}