cobalt_aws/s3/
multipartcopy.rs

1use aws_sdk_s3::error::SdkError;
2use aws_sdk_s3::operation::complete_multipart_upload::CompleteMultipartUploadError;
3use aws_sdk_s3::operation::create_multipart_upload::{
4    CreateMultipartUploadError, CreateMultipartUploadOutput,
5};
6use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
7use aws_sdk_s3::Client;
8use bytesize::{GIB, MIB, TIB};
9use conv::{prelude::*, FloatError, RangeError};
10use derive_more::{AsRef, Display, Into};
11use either::Either;
12use futures::stream::TryStreamExt;
13use std::error::Error as StdError;
14use std::num::TryFromIntError;
15use std::sync::Arc;
16use thiserror::Error;
17use tracing::instrument;
18use typed_builder::TypedBuilder;
19
20use super::S3Object;
21
22/// Retrieves the size of the source object in the specified S3 bucket.
23///
24/// # Arguments
25///
26/// * `client` - A reference to the S3 client.
27/// * `bucket` - The name of the S3 bucket.
28/// * `key` - The key of the source object.
29///
30/// # Returns
31///
32/// The size of the source object in bytes.
33#[instrument(skip(client))]
34async fn get_source_size(
35    client: &Client,
36    bucket: &str,
37    key: &str,
38) -> Result<SourceSize, S3MultipartCopierError> {
39    let head_object = client.head_object().bucket(bucket).key(key).send().await?;
40
41    let length = head_object
42        .content_length()
43        .ok_or(S3MultipartCopierError::MissingContentLength)?;
44    let source = SourceSize::try_from(length).map_err(S3MultipartCopierError::SourceSize)?;
45    Ok(source)
46}
47
/// The minimum allowed source size for an S3 object, set to 0 bytes.
///
/// This constant defines the minimum size an S3 object can be for operations that
/// require a size check. Setting it to 0 bytes allows for empty objects, which are
/// valid in S3 but might be restricted in some specific use cases.
/// Enforced by the `TryFrom<i64>` implementation for `SourceSize`.
const MIN_SOURCE_SIZE: i64 = 0;

/// The maximum allowed source size for an S3 object, set to 5 TiB.
///
/// This constant defines the maximum size an S3 object can be for operations that
/// require a size check. S3 objects can be up to 5 TiB (5 * 1024 * 1024 * 1024 * 1024 bytes) in size.
/// This limit ensures that objects are manageable and within S3's service constraints.
/// Enforced by the `TryFrom<i64>` implementation for `SourceSize`.
const MAX_SOURCE_SIZE: i64 = 5 * TIB as i64;
61
62/// Errors that can occur when creating a `SourceSize`.
63///
64/// `SourceSizeError` is used to indicate that a given size is either too small or too large
65/// to be used as an S3 object size.
66///
67/// # Variants
68///
69/// - `TooSmall(i64)`: Indicates that the source size is smaller than the minimum allowed size.
70/// - `TooLarge(i64)`: Indicates that the source size is larger than the maximum allowed size.
71#[derive(Debug, Error)]
72pub enum SourceSizeError {
73    #[error("S3 Object must be at least {MIN_SOURCE_SIZE} bytes. Object size was {0}")]
74    TooSmall(i64),
75    #[error("S3 Object must be at most {MAX_SOURCE_SIZE} bytes, Object size was {0}")]
76    TooLarge(i64),
77}
78
/// Represents a valid source size for an S3 object.
///
/// `SourceSize` ensures that the size of an S3 object is within the allowed range defined by S3.
/// The size must be at least 0 bytes (`MIN_SOURCE_SIZE`) and at most 5 TiB (`MAX_SOURCE_SIZE`).
/// Construct via `TryFrom<i64>`, which enforces those bounds.
#[derive(Debug, Display, Into, AsRef, Clone, Eq, PartialEq)]
#[into(owned, ref, ref_mut)]
pub struct SourceSize(i64);
86
87/// Attempts to create a `SourceSize` from an `i64` value.
88///
89/// The `TryFrom<i64>` implementation for `SourceSize` ensures that the given value is within
90/// the allowed range for S3 object sizes. If the value is within the range, it returns `Ok(SourceSize)`.
91/// Otherwise, it returns an appropriate `SourceSizeError`.
92///
93/// # Errors
94///
95/// - Returns `SourceSizeError::TooSmall` if the value is smaller than `MIN_SOURCE_SIZE`.
96/// - Returns `SourceSizeError::TooLarge` if the value is larger than `MAX_SOURCE_SIZE`.
97impl TryFrom<i64> for SourceSize {
98    type Error = SourceSizeError;
99
100    fn try_from(value: i64) -> Result<Self, Self::Error> {
101        if value < MIN_SOURCE_SIZE {
102            Err(SourceSizeError::TooSmall(value))
103        } else if value > MAX_SOURCE_SIZE {
104            Err(SourceSizeError::TooLarge(value))
105        } else {
106            Ok(SourceSize(value))
107        }
108    }
109}
110
/// The minimum allowed part size for S3 multipart uploads, set to 5 MiB.
///
/// S3 enforces a minimum part size of 5 MiB (5 * 1024 * 1024 bytes) for multipart uploads.
/// Parts smaller than this size are not allowed, except for the last part of the upload.
/// Ensuring that each part meets this minimum size requirement helps optimize the upload process
/// and ensures compatibility with S3's multipart upload API.
const MIN_PART_SIZE: i64 = 5 * MIB as i64;

/// The maximum allowed part size for S3 multipart uploads, set to 5 GiB.
///
/// S3 enforces a maximum part size of 5 GiB (5 * 1024 * 1024 * 1024 bytes) for multipart uploads.
/// Parts larger than this size are not allowed. This limitation helps prevent excessively large
/// parts from overwhelming the upload process and ensures that the upload is broken down into manageable
/// chunks. Adhering to this maximum size requirement is essential for successful multipart uploads.
const MAX_PART_SIZE: i64 = 5 * GIB as i64;
126
/// Represents a valid part size for S3 multipart uploads.
///
/// `PartSize` ensures that the size of each part used in multipart uploads is within
/// the allowed range defined by S3. The size must be at least 5 MiB and at most 5 GiB.
/// Construct via `TryFrom<i64>`, which enforces the bounds, or use `Default`
/// for a 50 MiB part size.
///
/// # Constants
///
/// - `MIN_PART_SIZE`: The minimum allowed part size (5 MiB).
/// - `MAX_PART_SIZE`: The maximum allowed part size (5 GiB).
#[derive(Debug, Into, AsRef, Clone)]
#[into(owned, ref, ref_mut)]
pub struct PartSize(i64);
139
140/// The default size of a block for S3 multipart copy operations, set to 50 MiB.
141///
142/// # Why 50 MiB?
143///
144/// - **Balance**: Optimizes between throughput and the number of API calls.
145/// - **S3 Limits**: Fits within S3's part size requirements (min 5 MiB, max 5 TB).
146/// - **Parallelism**: Allows efficient parallel uploads, speeding up the copy process.
147/// - **Error Handling**: Facilitates easier retries of failed parts without re-uploading the entire object.
148///
149/// This size ensures efficient, cost-effective, and reliable multipart copy operations.
150///
151/// # Note
152///
153/// While 50 MiB is a good default for many use cases, it might not be suitable for all operations.
154/// Adjust the part size based on your specific requirements and constraints.
155impl Default for PartSize {
156    fn default() -> Self {
157        const DEFAULT_COPY_PART_SIZE: i64 = 50 * MIB as i64;
158        Self(DEFAULT_COPY_PART_SIZE)
159    }
160}
161
162#[derive(Debug, Error)]
163pub enum PartSizeError {
164    #[error("part_size must be at least {MIN_PART_SIZE} bytes. part_size was {0}")]
165    TooSmall(i64),
166    #[error("part_size must be at most {MAX_PART_SIZE} bytes, part_size was {0}")]
167    TooLarge(i64),
168}
169
170/// Attempts to create a `PartSize` from an `i64` value.
171///
172/// The `TryFrom<i64>` implementation for `PartSize` ensures that the given value is within
173/// the allowed range for S3 multipart upload part sizes. If the value is within the range,
174/// it returns `Ok(PartSize)`. Otherwise, it returns an appropriate `PartSizeError`.
175///
176/// # Errors
177///
178/// - Returns `PartSizeError::TooSmall` if the value is smaller than `MIN_PART_SIZE`.
179/// - Returns `PartSizeError::TooLarge` if the value is larger than `MAX_PART_SIZE`.
180impl TryFrom<i64> for PartSize {
181    type Error = PartSizeError;
182
183    fn try_from(value: i64) -> Result<Self, Self::Error> {
184        if value < MIN_PART_SIZE {
185            Err(PartSizeError::TooSmall(value))
186        } else if value > MAX_PART_SIZE {
187            Err(PartSizeError::TooLarge(value))
188        } else {
189            Ok(PartSize(value))
190        }
191    }
192}
193
/// Errors that can occur when creating a `ByteRange`.
///
/// This enum captures the possible validation errors for a `ByteRange`.
#[derive(Debug, Error)]
pub enum ByteRangeError {
    /// The start byte is greater than the end byte.
    #[error("The start byte must be less than or equal to the end byte \n start: {0}, end: {1}")]
    InvalidRange(i64, i64),
    /// The start byte is negative.
    #[error("The start byte must be non-negative: \n start {0}")]
    NegativeStart(i64),
}
204
/// A struct representing an inclusive byte range.
///
/// `ByteRange` is used to define a range of bytes, typically for operations such as
/// copying a specific portion of an object from S3. It includes validation (via
/// `TryFrom<(i64, i64)>`) to ensure the byte range is valid, with the start byte
/// non-negative and less than or equal to the end byte. Both bounds are inclusive,
/// matching the HTTP `bytes=start-end` range format produced by `as_string`.
#[derive(Debug, Clone, Copy)]
pub struct ByteRange(i64, i64);
213
214impl TryFrom<(i64, i64)> for ByteRange {
215    type Error = ByteRangeError;
216
217    fn try_from(value: (i64, i64)) -> Result<Self, Self::Error> {
218        let (start, end) = value;
219
220        if start < 0 {
221            Err(ByteRangeError::NegativeStart(start))
222        } else if start > end {
223            Err(ByteRangeError::InvalidRange(start, end))
224        } else {
225            Ok(ByteRange(start, end))
226        }
227    }
228}
229
impl ByteRange {
    /// Generates an inclusive byte range string for S3 operations,
    /// in the HTTP range format `bytes=start-end`.
    ///
    /// # Returns
    ///
    /// A string representing the byte range.
    ///
    /// # Examples
    ///
    /// ```compile_fail
    /// let range = ByteRange::try_from((0, 499)).unwrap();
    /// assert_eq!(range.as_string(), "bytes=0-499");
    ///
    /// let range = ByteRange::try_from((500, 999)).unwrap();
    /// assert_eq!(range.as_string(), "bytes=500-999");
    /// ```
    pub fn as_string(&self) -> String {
        let ByteRange(start, end) = self;
        format!("bytes={}-{}", start, end)
    }
}
251
252/// Custom error types for S3 multipart copy operations.
253#[derive(Debug, Error)]
254pub enum S3MultipartCopierError {
255    #[error("Missing multipart upload id")]
256    MissingUploadId,
257    #[error("Missing copy part result")]
258    MissingCopyPartResult,
259    #[error("Missing content length")]
260    MissingContentLength,
261    #[error(transparent)]
262    RangeError(#[from] RangeError<i64>),
263    #[error(transparent)]
264    FloatError(#[from] FloatError<f64>),
265    #[error(transparent)]
266    TryFromIntError(#[from] TryFromIntError),
267    #[error(transparent)]
268    SourceSize(#[from] SourceSizeError),
269    #[error("PartSize larger than SourceSize \n Atomic copy should be use. part_size : {part_size}, source_size : {source_size}")]
270    PartSizeGreaterThanOrEqualSource { part_size: i64, source_size: i64 },
271    #[error("Can not perform multipart copy with source size 0")]
272    MultipartCopySourceSizeZero,
273    #[error(transparent)]
274    ByteRangeError(#[from] ByteRangeError),
275    #[error(transparent)]
276    S3Error(Box<dyn StdError + Send + Sync>),
277}
278
279impl<E: StdError + Send + Sync + 'static> From<SdkError<E>> for S3MultipartCopierError {
280    fn from(value: SdkError<E>) -> Self {
281        Self::S3Error(Box::new(value))
282    }
283}
284
/// A struct representing the parameters required for copying a part of an S3 object.
#[derive(Debug, TypedBuilder)]
struct CopyUploadPart<'a> {
    /// Source object being copied from.
    src: &'a S3Object,
    /// Destination object being copied to.
    dst: &'a S3Object,
    /// Upload id of the in-progress multipart upload.
    upload_id: &'a str,
    /// 1-based part number within the multipart upload.
    part_number: i32,
    /// Inclusive byte range of the source object to copy into this part.
    byte_range: ByteRange,
}
294
/// A struct to handle S3 multipart copy operations.
///
/// `S3MultipartCopier` facilitates copying large objects in S3 by breaking them into
/// smaller parts and uploading them in parallel. This is particularly useful for objects
/// larger than 5 GB, as S3's single-part copy operation is limited to this size. If the
/// source file is smaller than the part size, an atomic copy will be used instead, which
/// involves calling the S3 copy API to perform the copy operation in a single request.
///
/// # Fields
///
/// - `client`: An `Arc`-wrapped S3 `Client` used to perform the copy operations.
/// - `part_size`: The size of each part in bytes. Defaults to 50 MiB.
/// - `max_concurrent_uploads`: The maximum number of parts to upload concurrently.
/// - `source`: The `S3Object` representing the source object to copy.
/// - `destination`: The `S3Object` representing the destination object.
///
/// # Example
///
/// ```no_run
/// # tokio_test::block_on(async {
/// use aws_sdk_s3::Client;
/// use std::sync::Arc;
/// use cobalt_aws::s3::{S3MultipartCopier, S3Object, PartSize};
///
/// let shared_config = cobalt_aws::config::load_from_env().await.unwrap();
/// let client = Arc::new(Client::new(&shared_config));
/// let source = S3Object::new("source-bucket", "source-key");
/// let destination = S3Object::new("destination-bucket", "destination-key");
///
/// let copier = S3MultipartCopier::builder()
///     .client(client)
///     .part_size(PartSize::try_from(50 * 1024 * 1024).unwrap()) // 50 MiB
///     .max_concurrent_uploads(4)
///     .source(source)
///     .destination(destination)
///     .build();
///
/// copier.send().await.unwrap();
/// # })
/// ```
///
/// # Note
///
/// Ensure that the `part_size` is appropriate for your use case. While 50 MiB is a good default,
/// it might not be suitable for all operations. Adjust the part size based on your specific
/// requirements and constraints. Additionally, if the source file is smaller than the part size,
/// an atomic copy will be used instead of a multipart copy. An atomic copy involves calling the
/// S3 copy API to perform the copy operation in a single request.

#[derive(Debug, TypedBuilder)]
pub struct S3MultipartCopier {
    /// Shared S3 client used for all copy requests.
    client: Arc<Client>,
    /// Size of each copied part; defaults to 50 MiB.
    #[builder(default=PartSize::default())]
    part_size: PartSize,
    /// Maximum number of part copies run concurrently.
    max_concurrent_uploads: usize,
    /// Object to copy from.
    source: S3Object,
    /// Object to copy to.
    destination: S3Object,
}
353
impl S3MultipartCopier {
    /// Initiates a multipart upload for the destination object
    /// (`self.destination`'s bucket and key).
    ///
    /// # Returns
    ///
    /// The output of the multipart upload initiation, which carries the
    /// upload id used by all subsequent part uploads.
    #[instrument(skip(self))]
    async fn initiate_multipart_upload(
        &self,
    ) -> Result<CreateMultipartUploadOutput, SdkError<CreateMultipartUploadError>> {
        self.client
            .create_multipart_upload()
            .bucket(&self.destination.bucket)
            .key(&self.destination.key)
            .send()
            .await
    }

    /// Formats an S3 object as a `bucket/key` copy-source string.
    ///
    /// NOTE(review): the key is not URL-encoded here; keys containing
    /// special characters may need encoding per the S3 `CopySource`
    /// parameter requirements — confirm.
    fn copy_source(object: &S3Object) -> String {
        format!("{}/{}", object.bucket, object.key)
    }

    /// Uploads a part of the source object to the destination as part of the multipart upload.
    ///
    /// # Arguments
    ///
    /// * `part` - The `CopyUploadPart` containing the parameters for the upload.
    ///
    /// # Returns
    ///
    /// The completed part containing the ETag and part number.
    #[instrument(skip(self))]
    async fn upload_part_copy(
        &self,
        part: CopyUploadPart<'_>,
    ) -> Result<CompletedPart, S3MultipartCopierError> {
        let copy_source = S3MultipartCopier::copy_source(part.src);

        let response = self
            .client
            .upload_part_copy()
            .bucket(&part.dst.bucket)
            .key(&part.dst.key)
            .part_number(part.part_number)
            .upload_id(part.upload_id)
            .copy_source(copy_source)
            .copy_source_range(part.byte_range.as_string())
            .send()
            .await?;

        // The copy-part result carries the ETag needed to complete the
        // multipart upload, so its absence is an error.
        Ok(CompletedPart::builder()
            .set_e_tag(
                response
                    .copy_part_result
                    .ok_or(S3MultipartCopierError::MissingCopyPartResult)?
                    .e_tag,
            )
            .part_number(part.part_number)
            .build())
    }

    /// Completes the multipart upload by combining all parts.
    ///
    /// # Arguments
    ///
    /// * `upload_id` - The upload ID of the multipart upload.
    /// * `parts` - A vector of completed parts.
    ///
    /// # Returns
    ///
    /// An empty result indicating success.
    #[instrument(skip(self))]
    async fn complete_multipart_upload(
        &self,
        upload_id: &str,
        mut parts: Vec<CompletedPart>,
    ) -> Result<(), SdkError<CompleteMultipartUploadError>> {
        // Parts arrive out of order because uploads are buffered unordered
        // (see `try_buffer_unordered` in `multipart_copy`); sort them by
        // part number before completing the upload.
        parts.sort_by_key(|part| part.part_number);
        let completed_multipart_upload = CompletedMultipartUpload::builder()
            .set_parts(Some(parts))
            .build();

        self.client
            .complete_multipart_upload()
            .bucket(&self.destination.bucket)
            .key(&self.destination.key)
            .upload_id(upload_id)
            .multipart_upload(completed_multipart_upload)
            .send()
            .await?;

        Ok(())
    }

    /// Copies the source object to the destination, choosing between an
    /// atomic copy and a multipart copy based on the source object's size
    /// relative to the configured part size.
    ///
    /// # Returns
    ///
    /// An empty result indicating success.
    #[instrument(skip(self))]
    pub async fn send(&self) -> Result<(), S3MultipartCopierError> {
        tracing::info!("Starting multipart copy");
        let source_size =
            get_source_size(&self.client, &self.source.bucket, &self.source.key).await?;

        tracing::info!(
            source_size = source_size.as_ref(),
            part_size = self.part_size.as_ref(),
        );

        // If the part size is larger than or equal to the source size, an
        // atomic copy is faster and cheaper than a multipart copy.
        if self.part_size.as_ref() >= source_size.as_ref() {
            tracing::info!("Part size is greater than or equal to source size, using atomic copy");
            self.atomic_copy().await
        } else {
            tracing::info!("Source size is larger than part size, using multipart copy");
            self.multipart_copy(&source_size).await
        }
    }

    /// Copies the source object to the destination in a single
    /// `CopyObject` request.
    async fn atomic_copy(&self) -> Result<(), S3MultipartCopierError> {
        let copy_source = S3MultipartCopier::copy_source(&self.source);
        self.client
            .copy_object()
            .copy_source(copy_source)
            .bucket(&self.destination.bucket)
            .key(&self.destination.key)
            .send()
            .await?;
        Ok(())
    }

    /// Copies the source object to the destination via a multipart upload,
    /// running up to `max_concurrent_uploads` part copies concurrently.
    async fn multipart_copy(&self, source_size: &SourceSize) -> Result<(), S3MultipartCopierError> {
        // Defensive guard: `send` routes this case to `atomic_copy`, but
        // protect against direct misuse of this method.
        if self.part_size.as_ref() > source_size.as_ref() {
            return Err(S3MultipartCopierError::PartSizeGreaterThanOrEqualSource {
                part_size: *self.part_size.as_ref(),
                source_size: *source_size.as_ref(),
            });
        }

        let create_multipart_upload = self.initiate_multipart_upload().await?;
        let upload_id = create_multipart_upload
            .upload_id()
            .ok_or(S3MultipartCopierError::MissingUploadId)?;

        let parts = futures::stream::iter(Self::byte_ranges(source_size, &self.part_size))
            .map_ok(|(part_number, byte_range)| {
                let source = &self.source;
                let destination = &self.destination;
                // Owned copy so the `async move` block below can capture it.
                let upload_id = upload_id.to_string();

                async move {
                    tracing::info!(byte_range = ?byte_range);

                    let part = CopyUploadPart::builder()
                        .src(source)
                        .dst(destination)
                        .upload_id(&upload_id)
                        .part_number(i32::try_from(part_number)?)
                        .byte_range(byte_range)
                        .build();
                    tracing::debug!(part = ?part, "Copying");
                    self.upload_part_copy(part).await
                }
            })
            // Run up to `max_concurrent_uploads` part copies at once,
            // yielding results as they complete (not in order).
            .try_buffer_unordered(self.max_concurrent_uploads);

        let completed_parts: Vec<CompletedPart> = parts.try_collect().await?;

        tracing::info!(upload_id = upload_id, "All parts completed");
        self.complete_multipart_upload(upload_id, completed_parts)
            .await?;

        tracing::info!("MultipartCopy completed");
        Ok(())
    }

    /// Yields `(part_number, byte_range)` pairs covering the whole source.
    ///
    /// Part numbers start at 1 and byte ranges are inclusive. `Either` is
    /// used so the error and success paths share one iterator type. A
    /// zero-size source yields a single `MultipartCopySourceSizeZero` error.
    fn byte_ranges<'a>(
        source_size: &'a SourceSize,
        part_size: &'a PartSize,
    ) -> impl Iterator<Item = Result<(i64, ByteRange), S3MultipartCopierError>> + 'a {
        if *source_size.as_ref() == 0 {
            Either::Left(std::iter::once(Err(
                S3MultipartCopierError::MultipartCopySourceSizeZero,
            )))
        } else {
            let part_count = match S3MultipartCopier::part_count(source_size, part_size) {
                Ok(count) => count,
                Err(e) => return Either::Left(std::iter::once(Err(e))),
            };
            Either::Right((1..=part_count).map(move |part_number| {
                let part_size = *part_size.as_ref();
                let source_size = *source_size.as_ref();

                let byte_range_start = (part_number - 1) * part_size;
                // The final part may be shorter than `part_size`; clamp the
                // end of the range to the last byte of the object.
                let byte_range_end = std::cmp::min(part_number * part_size - 1, source_size - 1);

                let byte_range = ByteRange::try_from((byte_range_start, byte_range_end))?;
                Ok((part_number, byte_range))
            }))
        }
    }

    /// Computes the number of parts needed to cover `source_size` with
    /// parts of `part_size`, rounding up so the final partial part counts.
    ///
    /// The float round-trip is lossless here: sizes are bounded by
    /// `MAX_SOURCE_SIZE` (5 TiB), well within `f64`'s 53-bit integer
    /// precision.
    ///
    /// # Errors
    ///
    /// Returns `MultipartCopySourceSizeZero` if the source is empty.
    fn part_count(
        source_size: &SourceSize,
        part_size: &PartSize,
    ) -> Result<i64, S3MultipartCopierError> {
        let source_size = *source_size.as_ref();
        let part_size = *part_size.as_ref();

        if source_size == 0 {
            return Err(S3MultipartCopierError::MultipartCopySourceSizeZero);
        }

        Ok(((f64::value_from(source_size)? / f64::value_from(part_size)?).ceil()).approx()?)
    }
}
577
#[cfg(test)]
pub mod arbitrary {
    //! Proptest `Arbitrary` implementations for the validated size types,
    //! used by the property-based tests in this module's sibling `tests`.
    use derive_more::{AsRef, From, Into};
    use proptest::prelude::*;

    use super::{
        PartSize, SourceSize, MAX_PART_SIZE, MAX_SOURCE_SIZE, MIN_PART_SIZE, MIN_SOURCE_SIZE,
    };

    /// Generates part sizes uniformly across the full valid range.
    impl Arbitrary for PartSize {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (MIN_PART_SIZE..=MAX_PART_SIZE)
                .prop_map(|size| PartSize::try_from(size).unwrap())
                .boxed()
        }
    }

    /// Generates source sizes uniformly across the full valid range
    /// (including zero).
    impl Arbitrary for SourceSize {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (MIN_SOURCE_SIZE..=MAX_SOURCE_SIZE)
                .prop_map(|size| SourceSize::try_from(size).unwrap())
                .boxed()
        }
    }

    /// A `SourceSize` guaranteed to be at least 1 byte, for tests that
    /// require a non-empty source object.
    #[derive(Debug, Clone, PartialEq, Eq, AsRef, Into, From)]
    pub struct NonZeroSourceSize(SourceSize);

    /// Generates non-zero source sizes (`1..=MAX_SOURCE_SIZE`).
    impl Arbitrary for NonZeroSourceSize {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (1..=MAX_SOURCE_SIZE)
                .prop_map(|size| SourceSize::try_from(size).unwrap())
                .prop_map(NonZeroSourceSize)
                .boxed()
        }
    }
}
625
626#[cfg(test)]
627mod tests {
628
629    use self::arbitrary::NonZeroSourceSize;
630
631    use super::*;
632    use crate::localstack;
633    use crate::s3::test::*;
634    use crate::s3::{AsyncMultipartUpload, S3Object};
635    use ::function_name::named;
636    use anyhow::Result;
637    use aws_sdk_s3::Client;
638    use bytesize::MIB;
639    use futures::prelude::*;
640    use proptest::{prop_assert, prop_assert_eq};
641    use rand::Rng;
642    use thiserror::Error;
643
644    use test_strategy::proptest;
645
646    //Wrapper to allow anyhow to be used a std::error::Error
647    #[derive(Debug, Error)]
648    #[error(transparent)]
649    pub struct CustomError(#[from] anyhow::Error);
650
651    // *** Integration tests *** //
652    //Integration tests should be in src/tests but there is tight coupling with
653    //localstack which makes it hard to migrate away from this structure.
654    async fn localstack_test_client() -> Client {
655        localstack::test_utils::wait_for_localstack().await;
656        let shared_config = crate::config::load_from_env().await.unwrap();
657        let builder = aws_sdk_s3::config::Builder::from(&shared_config)
658            .force_path_style(true)
659            .build();
660        Client::from_conf(builder)
661    }
662
663    #[proptest(async = "tokio", cases = 3)]
664    #[named]
665    async fn test_multipart_copy(#[strategy(0_usize..=50*MIB as usize)] upload_size: usize) {
666        let client = Arc::new(localstack_test_client().await);
667        let test_bucket = "test-multipart-bucket";
668        let mut rng = seeded_rng(function_name!());
669        let src_key = gen_random_file_name(&mut rng);
670
671        let result = create_bucket(&client, test_bucket).await;
672        prop_assert!(result.is_ok(), "Error: {result:?}");
673
674        let src = S3Object::new(test_bucket, &src_key);
675        let src_bytes = generate_random_bytes(upload_size, &mut rng);
676
677        let part_size = MIN_PART_SIZE as usize;
678
679        let mut writer = AsyncMultipartUpload::new(&client, &src, part_size, None)
680            .await
681            .map_err(CustomError)?;
682        writer.write_all(&src_bytes).await?;
683        writer.close().await?;
684
685        //prop_assert!(put_result.is_ok(), "Result : {put_result:?}");
686
687        let dst_key = gen_random_file_name(&mut rng);
688
689        let dest = S3Object::new(test_bucket, dst_key);
690        let copyier = S3MultipartCopier::builder()
691            .client(client.clone())
692            .source(src)
693            .destination(dest.clone())
694            .part_size((5 * MIB as i64).try_into()?)
695            .max_concurrent_uploads(100)
696            .build();
697        copyier.send().await?;
698
699        let copied_bytes = fetch_bytes(&client, &dest).await.map_err(CustomError)?;
700        prop_assert_eq!(src_bytes, copied_bytes);
701    }
702
703    #[tokio::test]
704    #[named]
705    async fn test_zero_size_multipart_copy() -> Result<()> {
706        let client = Arc::new(localstack_test_client().await);
707        let test_bucket = "test-multipart-bucket";
708        let mut rng = seeded_rng(function_name!());
709        let src_key = gen_random_file_name(&mut rng);
710
711        create_bucket(&client, test_bucket).await?;
712
713        let src = S3Object::new(test_bucket, &src_key);
714        client
715            .put_object()
716            .bucket(test_bucket)
717            .key(src_key)
718            .body(Vec::default().into())
719            .send()
720            .await?;
721
722        let dst_key = gen_random_file_name(&mut rng);
723
724        let dest = S3Object::new(test_bucket, dst_key);
725        let copyier = S3MultipartCopier::builder()
726            .client(client.clone())
727            .source(src)
728            .destination(dest.clone())
729            .part_size((5 * MIB as i64).try_into()?)
730            .max_concurrent_uploads(2)
731            .build();
732        copyier.send().await?;
733        let copied_bytes = fetch_bytes(&client, &dest).await?;
734        assert_eq!(copied_bytes.len(), 0);
735
736        Ok(())
737    }
738
739    fn generate_random_bytes(length: usize, rng: &mut impl Rng) -> Vec<u8> {
740        (0..length).map(|_| rng.random()).collect()
741    }
742
743    #[proptest]
744    fn test_part_count_valid(non_zero_source: NonZeroSourceSize, part_size: PartSize) {
745        let source_size = non_zero_source.into();
746        let count = S3MultipartCopier::part_count(&source_size, &part_size)?;
747
748        // Assert that the result is Ok and the part count is positive
749        prop_assert!(count >= 0, "Expected count greater than 0");
750        // Validate the part count is correct
751        let expected_count = ((source_size.0 as f64) / (part_size.0 as f64)).ceil() as i64;
752        prop_assert_eq!(count, expected_count);
753    }
754
755    #[proptest]
756    fn test_part_count_small_source(part_size: PartSize, #[strategy(1_i64..=10)] source: i64) {
757        let source_size = SourceSize(source);
758
759        let count = S3MultipartCopier::part_count(&source_size, &part_size)?;
760        prop_assert_eq!(count, 1, "Expected a part count of 1");
761    }
762
763    #[proptest]
764    fn test_part_count_large_source(part_size: PartSize) {
765        let source_size = SourceSize(MAX_SOURCE_SIZE); // Very large source size
766
767        let count = S3MultipartCopier::part_count(&source_size, &part_size)?;
768
769        // Assert that the result is Ok and the part count is positive
770        prop_assert!(count >= 0, "Expected count greater than 0");
771
772        let expected_count = ((source_size.0 as f64) / (part_size.0 as f64)).ceil() as i64;
773        prop_assert_eq!(count, expected_count);
774    }
775
776    #[proptest]
777    fn test_part_count_error_on_zero_source_size(part_size: PartSize) {
778        let source_size = SourceSize(0); // Zero source size
779
780        let result = S3MultipartCopier::part_count(&source_size, &part_size);
781
782        // Assert that the result is an error and the error is MultipartCopySourceSizeZero
783        prop_assert!(result.is_err());
784
785        prop_assert!(matches!(
786            result.unwrap_err(),
787            S3MultipartCopierError::MultipartCopySourceSizeZero
788        ));
789    }
790
791    #[proptest]
792    fn test_source_size_within_limits(#[strategy(MIN_SOURCE_SIZE..=MAX_SOURCE_SIZE)] value: i64) {
793        let source_size = SourceSize::try_from(value)?;
794        prop_assert_eq!(source_size.as_ref(), &value);
795    }
796
797    #[proptest]
798    fn test_source_size_too_small(#[strategy(i64::MIN..MIN_SOURCE_SIZE)] value: i64) {
799        let source_size = SourceSize::try_from(value);
800        prop_assert!(source_size.is_err());
801        prop_assert!(
802            matches!(source_size.unwrap_err(), SourceSizeError::TooSmall(v) if v == value)
803        );
804    }
805
806    #[proptest]
807    fn test_source_size_too_large(#[strategy((MAX_SOURCE_SIZE + 1)..=i64::MAX)] value: i64) {
808        let source_size = SourceSize::try_from(value);
809        prop_assert!(source_size.is_err());
810        prop_assert!(
811            matches!(source_size.unwrap_err(), SourceSizeError::TooLarge(v) if v == value)
812        );
813    }
814
815    // PartSize
816    #[proptest]
817    fn test_part_size_within_limits(#[strategy(MIN_PART_SIZE..=MAX_PART_SIZE)] value: i64) {
818        let part_size = PartSize::try_from(value)?;
819        prop_assert_eq!(part_size.as_ref(), &value);
820    }
821
822    #[proptest]
823    fn test_part_size_too_small(#[strategy(i64::MIN..MIN_PART_SIZE)] value: i64) {
824        let part_size = PartSize::try_from(value);
825        prop_assert!(part_size.is_err());
826        prop_assert!(matches!(part_size.unwrap_err(), PartSizeError::TooSmall(v) if v == value));
827    }
828
829    #[proptest]
830    fn test_part_size_too_large(#[strategy((MAX_PART_SIZE + 1)..=i64::MAX)] value: i64) {
831        let part_size = PartSize::try_from(value);
832        prop_assert!(part_size.is_err());
833        prop_assert!(matches!(part_size.unwrap_err(), PartSizeError::TooLarge(v) if v == value));
834    }
835
836    //byte range
837    #[proptest]
838    fn valid_byte_range(#[strategy(0..i64::MAX)] start: i64, #[strategy(0..i64::MAX)] end: i64) {
839        if start <= end {
840            let range = ByteRange::try_from((start, end))?;
841            prop_assert_eq!(range.0, start);
842            prop_assert_eq!(range.1, end);
843        } else {
844            let range = ByteRange::try_from((start, end));
845            prop_assert!(
846                matches!(range, Err(ByteRangeError::InvalidRange(s, e)) if s == start && e == end)
847            );
848        }
849    }
850
851    #[proptest]
852    fn invalid_negative_start_byte_range(
853        #[strategy(i64::MIN..0)] start: i64,
854        #[strategy(0..i64::MAX)] end: i64,
855    ) {
856        let range = ByteRange::try_from((start, end));
857        prop_assert!(matches!(range, Err(ByteRangeError::NegativeStart(s)) if s == start));
858    }
859
860    #[proptest]
861    fn invalid_byte_range_start_greater_than_end(
862        #[strategy(0..i64::MAX)] start: i64,
863        #[strategy(0..i64::MAX)] end: i64,
864    ) {
865        if start > end {
866            let range = ByteRange::try_from((start, end));
867            prop_assert!(
868                matches!(range, Err(ByteRangeError::InvalidRange(s, e)) if s == start && e == end)
869            );
870        }
871    }
872
873    //bytes_ranges
874    #[proptest]
875    fn test_byte_ranges_valid(part_size: PartSize, non_zero_source: NonZeroSourceSize) {
876        let source_size = non_zero_source.into();
877
878        let result: Vec<_> = S3MultipartCopier::byte_ranges(&source_size, &part_size).collect();
879
880        for (i, item) in result.iter().enumerate() {
881            let (part_number, ByteRange(start, end)) = item.as_ref()?;
882            prop_assert_eq!(*part_number as usize, i + 1);
883            prop_assert!(start >= &0);
884            prop_assert!(end >= start);
885            prop_assert!(end < source_size.as_ref());
886        }
887        let mut expected_start = 0;
888        for item in &result {
889            let ByteRange(start, end) = item.as_ref()?.1;
890            prop_assert_eq!(start, expected_start);
891            expected_start = end + 1;
892        }
893        prop_assert_eq!(expected_start, *source_size.as_ref());
894    }
895
896    #[proptest]
897    fn test_byte_ranges_zero_source_size(part_size: PartSize) {
898        let source_size = SourceSize(0);
899
900        let result: Vec<_> = S3MultipartCopier::byte_ranges(&source_size, &part_size).collect();
901
902        prop_assert!(result.len() == 1);
903        let err = result[0].as_ref().err().unwrap();
904        prop_assert!(matches!(
905            err,
906            S3MultipartCopierError::MultipartCopySourceSizeZero
907        ));
908    }
909
910    #[proptest]
911    fn test_byte_ranges_large_source(
912        part_size: PartSize,
913        #[strategy(1_000_000_000_i64..10_000_000_000_i64)] source: i64,
914    ) {
915        let source_size = SourceSize(source);
916        let result: Vec<_> = S3MultipartCopier::byte_ranges(&source_size, &part_size).collect();
917
918        for (i, item) in result.iter().enumerate() {
919            prop_assert!(item.is_ok(), "Error {:?}", item);
920            let (part_number, ByteRange(start, end)) = item.as_ref()?;
921            prop_assert_eq!(*part_number as usize, i + 1);
922            prop_assert!(start >= &0);
923            prop_assert!(end >= start);
924            prop_assert!(end < source_size.as_ref());
925        }
926        let mut expected_start = 0;
927        for item in &result {
928            let ByteRange(start, end) = item.as_ref()?.1;
929            prop_assert_eq!(start, expected_start);
930            expected_start = end + 1;
931        }
932        prop_assert_eq!(expected_start, *source_size.as_ref());
933    }
934}