use anyhow::{Context, Result};
use async_channel::Sender;
use aws_sdk_s3::operation::get_object::GetObjectOutput;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::ChecksumAlgorithm;
use aws_smithy_types::DateTime;
use tokio::io::AsyncReadExt;
use crate::Config;
use crate::storage::Storage;
use crate::storage::checksum::AdditionalChecksum;
use crate::transfer::TransferOutcome;
use crate::types::token::PipelineCancellationToken;
use crate::types::{ObjectChecksum, SyncStatistics};
/// Reads at most `limit` bytes from the front of `reader` into a new buffer.
///
/// The reader is only borrowed, so bytes beyond the probed prefix remain
/// available to the caller afterwards. A result shorter than `limit` means the
/// reader reported end-of-input before the limit was reached.
///
/// # Errors
///
/// A failed read is returned with the context
/// `probe_up_to: failed to read from reader`.
async fn probe_up_to<R: tokio::io::AsyncRead + Unpin + ?Sized>(
    reader: &mut R,
    limit: usize,
) -> Result<Vec<u8>> {
    // Capacity for the whole probe window is reserved up front.
    let mut prefix = Vec::with_capacity(limit);

    // Cap a reborrow rather than the reader itself so the caller keeps the tail.
    let mut window = (&mut *reader).take(limit as u64);
    window
        .read_to_end(&mut prefix)
        .await
        .context("probe_up_to: failed to read from reader")?;

    Ok(prefix)
}
/// Uploads everything produced by `reader` to `target` under `target_key`.
///
/// Up to `config.transfer_config.multipart_threshold` bytes are probed from the
/// reader first. If the reader ends before that many bytes were read, the
/// probed bytes are handed to `transfer_buffered`; otherwise the probed prefix
/// together with the still-unread reader goes to `transfer_streaming`.
///
/// When `cancellation_token` is already cancelled nothing is read and
/// `TransferOutcome::default()` is returned immediately.
///
/// # Errors
///
/// Propagates a failure of the initial probe or of the selected upload path.
pub async fn transfer(
    config: &Config,
    target: Storage,
    target_key: &str,
    mut reader: impl tokio::io::AsyncRead + Unpin + Send + 'static,
    cancellation_token: PipelineCancellationToken,
    stats_sender: Sender<SyncStatistics>,
) -> Result<TransferOutcome> {
    if cancellation_token.is_cancelled() {
        return Ok(TransferOutcome::default());
    }

    let multipart_threshold = config.transfer_config.multipart_threshold as usize;
    let head = probe_up_to(&mut reader, multipart_threshold).await?;

    // A full probe window means there may be more data behind it: stream.
    // Anything shorter is the complete object and can be sent from memory.
    if head.len() >= multipart_threshold {
        transfer_streaming(
            config,
            target,
            target_key,
            head,
            reader,
            cancellation_token,
            stats_sender,
        )
        .await
    } else {
        transfer_buffered(
            config,
            target,
            target_key,
            head,
            cancellation_token,
            stats_sender,
        )
        .await
    }
}
/// Streams an object to `target` via `put_object_stream`.
///
/// `initial` holds bytes that were already read from the front of the source;
/// it is replayed ahead of `reader` so the target receives the bytes in their
/// original order.
///
/// Tagging from `config.tagging` is forwarded unless `config.disable_tagging`
/// is set. An `ObjectChecksum` carrying only the key and the configured
/// additional checksum algorithm is passed along; all other fields are `None`.
///
/// The cancellation token is accepted but not consulted by this function.
///
/// # Returns
///
/// `TransferOutcome::default()` once the upload call has succeeded.
///
/// # Errors
///
/// Returns the upload error with the context
/// `failed to stream to target: <target_key>`.
async fn transfer_streaming(
    config: &Config,
    target: Storage,
    target_key: &str,
    initial: Vec<u8>,
    reader: impl tokio::io::AsyncRead + Unpin + Send + 'static,
    _cancellation_token: PipelineCancellationToken,
    stats_sender: Sender<SyncStatistics>,
) -> Result<TransferOutcome> {
    // Re-attach the already-consumed prefix in front of the unread remainder.
    let chained: Box<dyn tokio::io::AsyncRead + Send + Unpin> =
        Box::new(std::io::Cursor::new(initial).chain(reader));

    let tagging = if config.disable_tagging {
        None
    } else {
        config.tagging.clone()
    };

    let object_checksum = ObjectChecksum {
        key: target_key.to_string(),
        version_id: None,
        checksum_algorithm: config.additional_checksum_algorithm.clone(),
        checksum_type: None,
        object_parts: None,
        final_checksum: None,
    };

    let _put_object_output = target
        .put_object_stream(target_key, chained, tagging, Some(object_checksum), None)
        .await
        // Lazy context: the message is only formatted when the upload failed.
        .with_context(|| format!("failed to stream to target: {target_key}"))?;

    // The send result is deliberately ignored so that a failed statistics
    // delivery does not turn a completed upload into an error.
    let _ = stats_sender
        .send(SyncStatistics::SyncComplete {
            key: target_key.to_string(),
        })
        .await;

    Ok(TransferOutcome::default())
}
/// Uploads an object that is fully held in `buffer` with one `put_object` call.
///
/// The buffer is wrapped in a synthetic `GetObjectOutput` whose content length
/// is the buffer size, whose content type comes from `config.content_type` and
/// whose last-modified time is the current UTC time. When an additional
/// checksum algorithm is configured, the checksum of the buffer is computed
/// locally and stored in the matching checksum field of that output.
///
/// Tagging from `config.tagging` is forwarded unless `config.disable_tagging`
/// is set. The cancellation token is accepted but not consulted here.
///
/// # Returns
///
/// `TransferOutcome::default()` once the upload call has succeeded.
///
/// # Errors
///
/// Returns the upload error with the context
/// `failed to upload to target: <target_key>`.
async fn transfer_buffered(
    config: &Config,
    target: Storage,
    target_key: &str,
    buffer: Vec<u8>,
    _cancellation_token: PipelineCancellationToken,
    stats_sender: Sender<SyncStatistics>,
) -> Result<TransferOutcome> {
    let target_clone = dyn_clone::clone_box(&*target);
    let source_size = buffer.len() as u64;

    // Computed before the buffer is moved into the byte stream below.
    let source_additional_checksum =
        config
            .additional_checksum_algorithm
            .clone()
            .map(|algorithm| {
                compute_source_checksum(
                    &buffer,
                    algorithm,
                    config.transfer_config.multipart_chunksize as usize,
                    config.transfer_config.multipart_threshold as usize,
                    config.full_object_checksum,
                )
            });

    // Route the checksum string into the one slot that matches the configured
    // algorithm. The wildcard arm covers both "no algorithm configured" and any
    // algorithm that has no dedicated slot here.
    let (checksum_sha256, checksum_sha1, checksum_crc32, checksum_crc32_c, checksum_crc64_nvme) =
        match config.additional_checksum_algorithm.as_ref() {
            Some(ChecksumAlgorithm::Sha256) => {
                (source_additional_checksum.clone(), None, None, None, None)
            }
            Some(ChecksumAlgorithm::Sha1) => {
                (None, source_additional_checksum.clone(), None, None, None)
            }
            Some(ChecksumAlgorithm::Crc32) => {
                (None, None, source_additional_checksum.clone(), None, None)
            }
            Some(ChecksumAlgorithm::Crc32C) => {
                (None, None, None, source_additional_checksum.clone(), None)
            }
            Some(ChecksumAlgorithm::Crc64Nvme) => {
                (None, None, None, None, source_additional_checksum.clone())
            }
            _ => (None, None, None, None, None),
        };

    let byte_stream = ByteStream::from(buffer);
    let get_object_output = GetObjectOutput::builder()
        .set_body(Some(byte_stream))
        // `source_size` comes from `Vec::len`, which never exceeds
        // `isize::MAX`, so this cast cannot wrap to a negative value.
        .set_content_length(Some(source_size as i64))
        .set_content_type(config.content_type.clone())
        .set_last_modified(Some(DateTime::from_secs(chrono::Utc::now().timestamp())))
        .set_checksum_sha256(checksum_sha256)
        .set_checksum_sha1(checksum_sha1)
        .set_checksum_crc32(checksum_crc32)
        .set_checksum_crc32_c(checksum_crc32_c)
        .set_checksum_crc64_nvme(checksum_crc64_nvme)
        .build();

    let tagging = if config.disable_tagging {
        None
    } else {
        config.tagging.clone()
    };

    let object_checksum = ObjectChecksum {
        key: target_key.to_string(),
        version_id: None,
        checksum_algorithm: config.additional_checksum_algorithm.clone(),
        checksum_type: None,
        object_parts: None,
        final_checksum: None,
    };

    // NOTE(review): `target_key` is passed twice and `target_clone` once —
    // presumably the target key, the source storage and the source key;
    // confirm against the `put_object` signature.
    let _put_object_output = target
        .put_object(
            target_key,
            target_clone,
            target_key,
            source_size,
            source_additional_checksum,
            get_object_output,
            tagging,
            Some(object_checksum),
            None,
        )
        .await
        // Lazy context: the message is only formatted when the upload failed.
        .with_context(|| format!("failed to upload to target: {target_key}"))?;

    // The send result is deliberately ignored so that a failed statistics
    // delivery does not turn a completed upload into an error.
    let _ = stats_sender
        .send(SyncStatistics::SyncComplete {
            key: target_key.to_string(),
        })
        .await;

    Ok(TransferOutcome::default())
}
/// Computes the additional checksum string for an in-memory object.
///
/// * Buffers shorter than `multipart_threshold` are fed to the checksum in one
///   `update` and the result of `finalize` is returned.
/// * Buffers at or above the threshold are walked in consecutive parts of at
///   most `multipart_chunksize` bytes; each part is fed with `update` followed
///   by `finalize` (whose return value is discarded), and the result of
///   `finalize_all` is returned.
///
/// `algorithm` and `full_object_checksum` are passed straight to
/// `AdditionalChecksum::new`.
///
/// NOTE(review): with `multipart_chunksize == 0` and a non-empty buffer on the
/// multipart path the loop never advances — callers are presumably expected to
/// pass a validated, non-zero chunk size; confirm.
fn compute_source_checksum(
    buffer: &[u8],
    algorithm: ChecksumAlgorithm,
    multipart_chunksize: usize,
    multipart_threshold: usize,
    full_object_checksum: bool,
) -> String {
    let mut hasher = AdditionalChecksum::new(algorithm, full_object_checksum);

    let takes_multipart_path = buffer.len() >= multipart_threshold;
    if !takes_multipart_path {
        hasher.update(buffer);
        return hasher.finalize();
    }

    // Peel one part at a time off the front of the not-yet-hashed remainder.
    let mut remaining = buffer;
    while !remaining.is_empty() {
        let part_len = multipart_chunksize.min(remaining.len());
        let (part, rest) = remaining.split_at(part_len);
        hasher.update(part);
        // Closes the current part; only the combined value is returned below.
        let _ = hasher.finalize();
        remaining = rest;
    }

    hasher.finalize_all()
}
#[cfg(test)]
mod probe_tests {
    use super::probe_up_to;
    use std::io::Cursor;

    /// Probes an in-memory reader made of `len` copies of `fill`, reading at
    /// most `limit` bytes, and returns what was probed.
    async fn probe_filled(fill: u8, len: usize, limit: usize) -> Vec<u8> {
        let mut source = Cursor::new(vec![fill; len]);
        probe_up_to(&mut source, limit).await.unwrap()
    }

    // Reader shorter than the limit: everything is returned.
    #[tokio::test]
    async fn returns_all_bytes_when_reader_smaller_than_limit() {
        let probed = probe_filled(1, 30, 100).await;
        assert_eq!(probed.len(), 30);
        assert_eq!(probed, vec![1u8; 30]);
    }

    // Reader longer than the limit: the result is capped at the limit.
    #[tokio::test]
    async fn returns_exactly_limit_bytes_when_reader_larger() {
        let probed = probe_filled(2, 200, 100).await;
        assert_eq!(probed.len(), 100);
        assert_eq!(probed, vec![2u8; 100]);
    }

    // Reader exactly as long as the limit.
    #[tokio::test]
    async fn returns_limit_bytes_when_reader_exactly_limit() {
        let probed = probe_filled(3, 100, 100).await;
        assert_eq!(probed.len(), 100);
        assert_eq!(probed, vec![3u8; 100]);
    }

    // Empty reader: nothing to probe.
    #[tokio::test]
    async fn returns_empty_for_empty_reader() {
        let mut source = Cursor::new(Vec::<u8>::new());
        let probed = probe_up_to(&mut source, 100).await.unwrap();
        assert!(probed.is_empty());
    }

    // Probing must not swallow the bytes that follow the probed prefix.
    #[tokio::test]
    async fn leaves_remaining_bytes_in_reader() {
        let mut source = Cursor::new(vec![5u8; 50]);

        let head = probe_up_to(&mut source, 20).await.unwrap();
        assert_eq!(head, vec![5u8; 20]);

        let mut tail = Vec::new();
        tokio::io::AsyncReadExt::read_to_end(&mut source, &mut tail)
            .await
            .unwrap();
        assert_eq!(tail.len(), 30);
    }
}
#[cfg(test)]
mod checksum_invariant_tests {
    use super::compute_source_checksum;
    use crate::storage::checksum::AdditionalChecksum;
    use aws_sdk_s3::types::ChecksumAlgorithm;

    /// Part size used by every scenario in this module.
    const CHUNKSIZE: usize = 1024;

    /// The algorithms every invariant is checked against.
    fn all_algorithms() -> Vec<ChecksumAlgorithm> {
        vec![
            ChecksumAlgorithm::Sha256,
            ChecksumAlgorithm::Sha1,
            ChecksumAlgorithm::Crc32,
            ChecksumAlgorithm::Crc32C,
            ChecksumAlgorithm::Crc64Nvme,
        ]
    }

    /// Reference multipart checksum: feeds the buffer part by part, closing
    /// each part with `finalize`, and returns `finalize_all`.
    fn streaming_checksum(
        buffer: &[u8],
        algorithm: ChecksumAlgorithm,
        multipart_chunksize: usize,
        full_object_checksum: bool,
    ) -> String {
        let mut running = AdditionalChecksum::new(algorithm, full_object_checksum);
        for part in buffer.chunks(multipart_chunksize) {
            running.update(part);
            let _ = running.finalize();
        }
        running.finalize_all()
    }

    /// Reference single-part checksum: one `update` followed by `finalize`,
    /// with the full-object flag off.
    fn single_part_checksum(buffer: &[u8], algorithm: ChecksumAlgorithm) -> String {
        let mut one_shot = AdditionalChecksum::new(algorithm, false);
        one_shot.update(buffer);
        one_shot.finalize()
    }

    // Several full parts plus a short trailing part.
    #[test]
    fn streaming_matches_buffered_for_multipart_sizes() {
        let threshold = CHUNKSIZE;
        let payload = vec![0xABu8; CHUNKSIZE * 4 + 17];

        for algo in all_algorithms() {
            let batched =
                compute_source_checksum(&payload, algo.clone(), CHUNKSIZE, threshold, false);
            let streamed = streaming_checksum(&payload, algo.clone(), CHUNKSIZE, false);
            assert_eq!(
                batched, streamed,
                "algorithm {:?}: batched vs streamed checksum mismatch",
                algo
            );
        }
    }

    // Only full parts, no trailing remainder.
    #[test]
    fn streaming_matches_buffered_for_exact_chunksize_multiples() {
        let threshold = CHUNKSIZE;
        let payload = vec![0x5Au8; CHUNKSIZE * 3];

        for algo in all_algorithms() {
            let batched =
                compute_source_checksum(&payload, algo.clone(), CHUNKSIZE, threshold, false);
            let streamed = streaming_checksum(&payload, algo.clone(), CHUNKSIZE, false);
            assert_eq!(batched, streamed, "algorithm {:?}", algo);
        }
    }

    // Below the threshold the buffer is hashed as one piece, even though it is
    // longer than a single chunk.
    #[test]
    fn sub_threshold_matches_single_update_finalize() {
        let threshold = 4096usize;
        let payload = vec![0xC3u8; 1500];

        for algo in all_algorithms() {
            let actual =
                compute_source_checksum(&payload, algo.clone(), CHUNKSIZE, threshold, false);
            assert_eq!(
                actual,
                single_part_checksum(&payload, algo.clone()),
                "algorithm {:?}: sub-threshold path should match single update+finalize",
                algo
            );
        }
    }

    // An empty buffer must equal a checksum finalized without any update.
    #[test]
    fn empty_buffer_produces_well_defined_checksum() {
        let threshold = CHUNKSIZE;
        let payload: Vec<u8> = Vec::new();

        for algo in all_algorithms() {
            let actual =
                compute_source_checksum(&payload, algo.clone(), CHUNKSIZE, threshold, false);
            let mut untouched = AdditionalChecksum::new(algo.clone(), false);
            assert_eq!(actual, untouched.finalize(), "algorithm {:?}", algo);
        }
    }

    // A buffer of exactly `threshold` bytes must already use the multipart path.
    #[test]
    fn threshold_boundary_uses_multipart_path() {
        let threshold = CHUNKSIZE;
        let payload = vec![0x11u8; threshold];

        for algo in all_algorithms() {
            let multipart =
                compute_source_checksum(&payload, algo.clone(), CHUNKSIZE, threshold, false);
            let single_part = single_part_checksum(&payload, algo.clone());

            // NOTE(review): CRC64NVME is exempt — presumably because it yields
            // the same value on both paths; confirm.
            if matches!(algo, ChecksumAlgorithm::Crc64Nvme) {
                continue;
            }
            assert_ne!(
                multipart, single_part,
                "algorithm {:?}: threshold boundary should take multipart path",
                algo
            );
        }
    }

    // Flipping `full_object_checksum` must change the multipart result for the
    // CRC32 family.
    #[test]
    fn full_object_checksum_flag_is_threaded_through() {
        let threshold = CHUNKSIZE;
        let payload = vec![0x77u8; CHUNKSIZE * 2 + 5];

        for algo in [ChecksumAlgorithm::Crc32, ChecksumAlgorithm::Crc32C] {
            let composite =
                compute_source_checksum(&payload, algo.clone(), CHUNKSIZE, threshold, false);
            let full_object =
                compute_source_checksum(&payload, algo.clone(), CHUNKSIZE, threshold, true);
            assert_ne!(
                composite, full_object,
                "algorithm {:?}: full_object_checksum flag should change result",
                algo
            );
        }
    }
}