use crate::core::config::Config;
use anyhow::{Context, Result};
use aws_sdk_s3::Client;
use aws_sdk_s3::config::{Credentials, SharedCredentialsProvider};
use aws_sdk_s3::error::ProvideErrorMetadata;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart, Delete, ObjectIdentifier};
use std::path::Path;
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use crate::coordinator::ProgressInfo;
use crate::db;
use crate::utils::{lock_async_mutex, lock_mutex};
use rusqlite::Connection;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::Mutex as AsyncMutex;
use tracing::{error, info};
use base64::{Engine as _, engine::general_purpose};
use sha2::{Digest, Sha256};
/// RAII wrapper around a spawned Tokio task that aborts the task when the
/// guard is dropped, so background work (e.g. the progress monitor in
/// `upload_file`) cannot outlive its owning scope.
pub struct TaskGuard(pub tokio::task::JoinHandle<()>);
impl Drop for TaskGuard {
    fn drop(&mut self) {
        // Abort rather than detach: dropping a bare JoinHandle would leave
        // the task running.
        self.0.abort();
    }
}
/// A single entry produced by `Uploader::list_bucket_contents`.
#[derive(Debug, Clone)]
pub struct S3Object {
    // Full object key, including any configured prefix.
    pub key: String,
    // Display name (key with the listing prefix stripped) and size in
    // bytes (0 for directory entries).
    pub name: String, pub size: i64,
    // Last-modified timestamp as formatted by the SDK; empty for
    // directory entries.
    pub last_modified: String,
    // True when this entry came from a CommonPrefix ("folder").
    pub is_dir: bool,
    // Reserved for a ".." parent entry; never set to true in this module.
    pub is_parent: bool,
}
/// S3 upload/listing facade. Currently stateless: every method receives
/// the application `Config` explicitly and builds its own client.
#[derive(Clone)]
pub struct Uploader {}
impl Uploader {
/// Creates a new `Uploader`. The config argument is currently unused
/// because all methods take `&Config` per call.
pub fn new(_config: &Config) -> Self {
    Self {}
}
/// Builds an S3 client from the application `Config` and returns it with
/// the configured bucket name.
///
/// Fails when no bucket is configured, or when only one of the access
/// key / secret key pair is set. When a custom endpoint is configured,
/// path-style addressing is forced (required by most S3-compatible
/// stores).
async fn create_client(config: &Config) -> Result<(Client, String)> {
    let bucket = config.s3_bucket.as_deref().unwrap_or_default();
    if bucket.is_empty() {
        return Err(anyhow::anyhow!("S3 bucket not configured"));
    }
    let region = config.s3_region.as_deref().unwrap_or("us-east-1").trim();
    let access_key = config.s3_access_key.as_deref().map(str::trim);
    let secret_key = config.s3_secret_key.as_deref().map(str::trim);
    #[allow(deprecated)]
    let mut loader =
        aws_config::from_env().region(aws_config::Region::new(region.to_string()));
    match (access_key, secret_key) {
        (Some(ak), Some(sk)) => {
            // Explicit static credentials override the environment chain.
            let creds = Credentials::new(ak.to_string(), sk.to_string(), None, None, "static");
            loader = loader.credentials_provider(SharedCredentialsProvider::new(creds));
        }
        (None, None) => {
            // Fall back entirely to the default provider chain.
        }
        _ => {
            return Err(anyhow::anyhow!(
                "S3 Credentials incomplete: Both Access Key and Secret Key must be provided."
            ));
        }
    }
    let sdk_config = loader.load().await;
    let mut builder = aws_sdk_s3::config::Builder::from(&sdk_config);
    if let Some(endpoint) = config.s3_endpoint.as_deref() {
        builder = builder.endpoint_url(endpoint).force_path_style(true);
    }
    Ok((Client::from_conf(builder.build()), bucket.to_string()))
}
pub async fn list_bucket_contents(
config: &Config,
subdir: Option<&str>,
) -> Result<Vec<S3Object>> {
let (client, bucket) = Self::create_client(config).await?;
let base_prefix = config.s3_prefix.as_deref().unwrap_or("");
let prefix = if let Some(sub) = subdir {
format!("{}{}", base_prefix, sub)
} else {
base_prefix.to_string()
};
let mut objects = Vec::new();
let mut continuation_token = None;
loop {
let mut req = client.list_objects_v2().bucket(&bucket).delimiter("/");
if !prefix.is_empty() {
req = req.prefix(&prefix);
}
if let Some(token) = continuation_token {
req = req.continuation_token(token);
}
let resp = req.send().await.context("Failed to list objects")?;
let is_truncated = resp.is_truncated().unwrap_or(false);
let next_token = resp.next_continuation_token.clone();
if let Some(common_prefixes) = resp.common_prefixes {
for cp in common_prefixes {
if let Some(key) = cp.prefix {
let name = if key.starts_with(&prefix) {
key[prefix.len()..].to_string()
} else {
key.clone()
};
objects.push(S3Object {
key,
name,
size: 0,
last_modified: String::new(),
is_dir: true,
is_parent: false,
});
}
}
}
if let Some(contents) = resp.contents {
for obj in contents {
let key = obj.key().unwrap_or("").to_string();
if key == prefix {
continue;
}
let size = obj.size().unwrap_or(0);
let last_modified = obj
.last_modified()
.map(|d| d.to_string())
.unwrap_or_default();
let name = if key.starts_with(&prefix) {
key[prefix.len()..].to_string()
} else {
key.clone()
};
objects.push(S3Object {
key,
name,
size,
last_modified,
is_dir: false,
is_parent: false,
});
}
}
if is_truncated {
continuation_token = next_token.map(|s| s.to_string());
} else {
break;
}
}
objects.sort_by(|a, b| {
if a.is_dir && !b.is_dir {
std::cmp::Ordering::Less
} else if !a.is_dir && b.is_dir {
std::cmp::Ordering::Greater
} else {
a.name.cmp(&b.name)
}
});
Ok(objects)
}
/// Downloads the object `key` from the configured bucket into the local
/// file `dest`, streaming the body directly to disk.
pub async fn download_file(config: &Config, key: &str, dest: &Path) -> Result<()> {
    let (client, bucket) = Self::create_client(config).await?;
    let object = client
        .get_object()
        .bucket(&bucket)
        .key(key)
        .send()
        .await
        .context("Failed to get object")?;
    let mut reader = object.body.into_async_read();
    let mut writer = tokio::fs::File::create(dest)
        .await
        .context("Failed to create local file")?;
    tokio::io::copy(&mut reader, &mut writer)
        .await
        .context("Failed to write to file")?;
    Ok(())
}
/// Deletes the single object identified by `key` from the configured
/// bucket.
pub async fn delete_file(config: &Config, key: &str) -> Result<()> {
    let (client, bucket) = Self::create_client(config).await?;
    let request = client.delete_object().bucket(&bucket).key(key);
    request
        .send()
        .await
        .context("Failed to delete object")?;
    Ok(())
}
/// Returns `true` when the folder at `folder_key` has no children.
///
/// The folder's own zero-byte placeholder object does not count as a
/// child; any sub-folder (CommonPrefix) or other object does.
pub async fn is_folder_empty(config: &Config, folder_key: &str) -> Result<bool> {
    let (client, bucket) = Self::create_client(config).await?;
    let prefix = if folder_key.ends_with('/') {
        folder_key.to_string()
    } else {
        format!("{}/", folder_key)
    };
    // max_keys(2) is enough: the placeholder plus at most one real child.
    let resp = client
        .list_objects_v2()
        .bucket(&bucket)
        .prefix(&prefix)
        .delimiter("/")
        .max_keys(2)
        .send()
        .await
        .context("Failed to list folder contents")?;
    // Any sub-folder means the folder is not empty.
    if resp
        .common_prefixes
        .as_ref()
        .is_some_and(|prefixes| !prefixes.is_empty())
    {
        return Ok(false);
    }
    // Any object other than the placeholder itself means not empty.
    let has_object_children = resp.contents.as_ref().is_some_and(|contents| {
        contents
            .iter()
            .any(|obj| obj.key().is_some_and(|k| k != prefix))
    });
    Ok(!has_object_children)
}
/// Recursively deletes every object under `folder_key` (treated as a
/// directory prefix) and returns the number of objects removed.
///
/// Lists without a delimiter, so nested objects are included. When a page
/// yields no objects, the zero-byte folder placeholder object itself is
/// deleted best-effort; that placeholder delete is NOT counted in the
/// returned total.
pub async fn delete_folder_recursive(config: &Config, folder_key: &str) -> Result<u64> {
    let (client, bucket) = Self::create_client(config).await?;
    let mut prefix = folder_key.to_string();
    // Normalize to a directory-style prefix.
    if !prefix.ends_with('/') {
        prefix.push('/');
    }
    let mut continuation_token = None;
    let mut deleted: u64 = 0;
    loop {
        let mut req = client.list_objects_v2().bucket(&bucket).prefix(&prefix);
        if let Some(token) = continuation_token.take() {
            req = req.continuation_token(token);
        }
        let resp = req.send().await.context("Failed to list folder contents")?;
        let mut identifiers: Vec<ObjectIdentifier> = Vec::new();
        // Capture pagination state before consuming the response fields.
        let is_truncated = resp.is_truncated().unwrap_or(false);
        let next_token = resp.next_continuation_token.clone();
        let objects = resp.contents.unwrap_or_default();
        for obj in objects {
            if let Some(key) = obj.key {
                identifiers.push(ObjectIdentifier::builder().key(key).build()?);
            }
        }
        if identifiers.is_empty() {
            // Empty page: remove the folder placeholder, ignoring errors
            // (it may not exist).
            let _ = client
                .delete_object()
                .bucket(&bucket)
                .key(&prefix)
                .send()
                .await;
        } else {
            let count = identifiers.len() as u64;
            // quiet(true): only errors come back in the response.
            let delete = Delete::builder()
                .set_objects(Some(identifiers))
                .quiet(true)
                .build()?;
            client
                .delete_objects()
                .bucket(&bucket)
                .delete(delete)
                .send()
                .await
                .context("Failed to delete folder objects")?;
            deleted += count;
        }
        if is_truncated {
            continuation_token = next_token.map(|s| s.to_string());
        } else {
            break;
        }
    }
    Ok(deleted)
}
/// Creates an empty "folder" by writing a zero-byte object whose key
/// ends with `/` (the usual S3 folder-placeholder convention).
pub async fn create_folder(config: &Config, folder_key: &str) -> Result<()> {
    let (client, bucket) = Self::create_client(config).await?;
    let mut key = folder_key.to_string();
    if !key.ends_with('/') {
        key.push('/');
    }
    client
        .put_object()
        .bucket(&bucket)
        .key(&key)
        .content_length(0)
        .send()
        .await
        .context("Failed to create folder")?;
    info!("Created S3 folder: {}", key);
    Ok(())
}
/// Verifies S3 connectivity and returns a human-readable success message.
///
/// Tries `ListBuckets` first; if that fails (common with credentials
/// scoped to a single bucket), falls back to `HeadBucket` on the
/// configured bucket. The original `ListBuckets` error is reported when
/// both probes fail.
pub async fn check_connection(config: &Config) -> Result<String> {
    let (client, bucket) = Self::create_client(config).await?;
    let list_err = match client.list_buckets().send().await {
        Ok(_) => {
            return Ok(format!("Connected to S3 successfully (Bucket: {})", bucket));
        }
        Err(e) => e,
    };
    if client.head_bucket().bucket(&bucket).send().await.is_ok() {
        Ok(format!(
            "Connected to S3 successfully (Access to '{}' confirmed)",
            bucket
        ))
    } else {
        Err(anyhow::anyhow!("S3 Connection Failed: {}", list_err))
    }
}
/// Uploads the file at `path` to S3 as a multipart upload, with resume
/// support, bounded part-level concurrency, per-part SHA-256 checksums,
/// and periodic progress reporting.
///
/// The target key is `s3_prefix` + the job's stored `s3_key` (falling
/// back to the source file name). Returns `Ok(true)` on success and
/// `Ok(false)` when `cancellation_token` was set before all parts were
/// dispatched. On part failure the stored upload id is kept in the DB so
/// a later run can resume.
#[allow(clippy::too_many_arguments)]
pub async fn upload_file(
    &self,
    config: &Config,
    path: &str,
    job_id: i64,
    // Shared job-id -> ProgressInfo map polled by the coordinator/UI.
    progress: Arc<AsyncMutex<HashMap<i64, ProgressInfo>>>,
    conn_mutex: Arc<Mutex<Connection>>,
    // Existing multipart upload id to resume, if any.
    initial_upload_id: Option<String>,
    cancellation_token: Arc<AtomicBool>,
) -> Result<bool> {
    let (client, bucket) = Self::create_client(config).await?;
    let prefix = config.s3_prefix.as_deref().unwrap_or_default();
    let concurrency = config.concurrency_upload_parts.max(1);
    let file_path = Path::new(path);
    // Resolve the object key: prefer the job's stored s3_key, then the
    // job's source file name, then the local path's file name.
    let s3_key_path = {
        let conn = lock_mutex(&conn_mutex)?;
        if let Ok(Some(job)) = db::get_job(&conn, job_id) {
            if let Some(key) = job.s3_key {
                key
            } else {
                let source = Path::new(&job.source_path);
                source
                    .file_name()
                    .map(|n| n.to_string_lossy().to_string())
                    .unwrap_or_else(|| {
                        file_path
                            .file_name()
                            .map(|n| n.to_string_lossy().to_string())
                            .unwrap_or_else(|| "file".to_string())
                    })
            }
        } else {
            file_path
                .file_name()
                .map(|n| n.to_string_lossy().to_string())
                .unwrap_or_else(|| "file".to_string())
        }
    };
    let key = format!("{}{}", prefix, s3_key_path);
    let mut upload_id = initial_upload_id.clone();
    let mut completed_parts = Vec::new();
    let mut completed_indices = HashSet::new();
    // Part number -> raw SHA-256 digest, used for the local composite
    // checksum recorded after completion.
    let part_hashes = Arc::new(Mutex::new(HashMap::<i32, Vec<u8>>::new()));
    if let Some(uid) = &upload_id {
        info!("Resuming upload for job {} (Upload ID: {})", job_id, uid);
        {
            let mut p = lock_async_mutex(&progress).await;
            p.insert(
                job_id,
                ProgressInfo {
                    // -1.0 signals an indeterminate percentage.
                    percent: -1.0,
                    details: "Resuming: Listing parts...".to_string(),
                    parts_done: 0,
                    parts_total: 0,
                },
            );
        }
        // Ask S3 which parts already exist so they can be skipped below.
        let list_parts_output = client
            .list_parts()
            .bucket(&bucket)
            .key(&key)
            .upload_id(uid)
            .send()
            .await;
        match list_parts_output {
            Ok(output) => {
                if let Some(parts) = output.parts {
                    for p in parts {
                        if let Some(pn) = p.part_number
                            && let Some(etag) = p.e_tag
                        {
                            let mut builder =
                                CompletedPart::builder().part_number(pn).e_tag(etag);
                            if let Some(cs) = p.checksum_sha256 {
                                // Keep the raw digest so the resumed parts
                                // contribute to the local composite
                                // checksum too.
                                if let Ok(bytes) = general_purpose::STANDARD.decode(&cs)
                                    && let Ok(mut map) = part_hashes.lock()
                                {
                                    map.insert(pn, bytes);
                                }
                                builder = builder.checksum_sha256(cs);
                            }
                            completed_parts.push(builder.build());
                            completed_indices.insert(pn);
                        }
                    }
                }
            }
            Err(_) => {
                // Stored upload id no longer valid (expired/aborted):
                // fall through and start a fresh multipart upload.
                upload_id = None;
            }
        }
    }
    if upload_id.is_none() {
        let create_output = client
            .create_multipart_upload()
            .bucket(&bucket)
            .key(&key)
            .checksum_algorithm(aws_sdk_s3::types::ChecksumAlgorithm::Sha256)
            .send()
            .await
            .map_err(|e| {
                // Surface the service error code/message when available.
                let service_err = match e.as_service_error() {
                    Some(s) => format!(
                        "{}: {}",
                        s.code().unwrap_or("Unknown"),
                        s.message().unwrap_or("No message")
                    ),
                    None => e.to_string(),
                };
                error!(
                    "Failed to create multipart upload: {} (Bucket: {}, Key: {})",
                    service_err, bucket, key
                );
                anyhow::anyhow!(
                    "Failed to create multipart upload: {} (Bucket: {}, Key: {})",
                    service_err,
                    bucket,
                    key
                )
            })?;
        let new_uid = create_output
            .upload_id()
            .context("No upload ID")?
            .to_string();
        info!(
            "Created new multipart upload for job {}: {}",
            job_id, new_uid
        );
        {
            // Persist the upload id so a later run can resume this upload.
            let conn = lock_mutex(&conn_mutex)?;
            db::update_job_upload_id(&conn, job_id, &new_uid)?;
        }
        upload_id = Some(new_uid);
    }
    // Shadow with the unwrapped id; guaranteed Some by the branch above.
    let upload_id =
        upload_id.ok_or_else(|| anyhow::anyhow!("Failed to obtain or create Upload ID"))?;
    let mut file = File::open(path).await.context("Failed to open file")?;
    let file_size = file.metadata().await?.len();
    // Part sizing: start from the configured size, shrink so
    // `concurrency` parts span the file, clamp to the 5 MiB S3 minimum,
    // then raise if needed to stay under the 10,000-part limit. Mirrors
    // `calculate_part_size` below.
    let min_part_size = 5 * 1024 * 1024;
    let max_parts = 10000;
    let config_part_size = (config.part_size_mb * 1024 * 1024) as usize;
    let parallel_part_size = (file_size as usize)
        .checked_div(concurrency)
        .unwrap_or(file_size as usize);
    let mut part_size = config_part_size.min(parallel_part_size).max(min_part_size);
    let required_min_part = file_size.div_ceil(max_parts);
    part_size = part_size.max(required_min_part as usize);
    let semaphore = Arc::new(tokio::sync::Semaphore::new(concurrency));
    let uploaded_bytes = Arc::new(std::sync::atomic::AtomicU64::new(0));
    // NOTE(review): assumes every resumed part was full-sized and that
    // the resumed upload used the same part size as this run; the final
    // part may be smaller, so this can slightly overcount progress.
    let initial_bytes: u64 = completed_indices.len() as u64 * part_size as u64;
    uploaded_bytes.store(initial_bytes, std::sync::atomic::Ordering::Relaxed);
    let net_bytes = uploaded_bytes.clone();
    let current_part = Arc::new(std::sync::atomic::AtomicU64::new(1));
    let monitor_progress = progress.clone();
    let monitor_net = net_bytes.clone();
    let m_job_id = job_id;
    let m_file_size = file_size;
    let m_part_size = part_size;
    let m_total_parts = (file_size as usize).div_ceil(part_size);
    // Background task refreshing the shared progress entry every 250 ms.
    let monitor_handle = tokio::spawn(async move {
        let start_time = std::time::Instant::now();
        loop {
            tokio::time::sleep(tokio::time::Duration::from_millis(250)).await;
            let n = monitor_net.load(std::sync::atomic::Ordering::Relaxed);
            let pct = (n as f64 / m_file_size as f64) * 100.0;
            let part_size_display = if m_part_size >= 1024 * 1024 * 1024 {
                format!("{:.2}GB", m_part_size as f64 / 1024.0 / 1024.0 / 1024.0)
            } else {
                format!("{:.2}MB", m_part_size as f64 / 1024.0 / 1024.0)
            };
            let elapsed = start_time.elapsed();
            let elapsed_str = if elapsed.as_secs() >= 60 {
                format!("{}m{}s", elapsed.as_secs() / 60, elapsed.as_secs() % 60)
            } else {
                format!("{}s", elapsed.as_secs())
            };
            let parts_done = n / m_part_size as u64;
            let details = format!(
                "{}/{} parts, {} sized [{}]",
                parts_done, m_total_parts, part_size_display, elapsed_str
            );
            {
                let mut p = lock_async_mutex(&monitor_progress).await;
                p.insert(
                    m_job_id,
                    ProgressInfo {
                        percent: pct,
                        details,
                        parts_done: parts_done as usize,
                        parts_total: m_total_parts,
                    },
                );
            }
            if n >= m_file_size {
                // NOTE(review): intentionally empty? The loop never exits
                // on its own and relies on TaskGuard::drop aborting the
                // task; a `break` here would let it stop once all bytes
                // are reported. Confirm before changing.
            }
        }
    });
    // Aborts the monitor when this function returns (success or error).
    let _monitor_guard = TaskGuard(monitor_handle);
    let mut handles: Vec<tokio::task::JoinHandle<Result<CompletedPart>>> = Vec::new();
    let mut part_number = 1;
    loop {
        if cancellation_token.load(Ordering::Relaxed) {
            // Cancelled: abort in-flight part uploads and report "not
            // done". The upload id stays in the DB for a future resume.
            for h in handles {
                h.abort();
            }
            return Ok(false);
        }
        if completed_indices.contains(&part_number) {
            // Part already uploaded in a previous run: skip its range.
            use std::io::SeekFrom;
            let current_pos = (part_number as u64 - 1) * part_size as u64;
            // NOTE(review): the seek error is discarded (`unwrap_or` on
            // the returned position), and this seek is redundant anyway —
            // each uploaded part seeks to its own absolute offset below.
            file.seek(SeekFrom::Start(current_pos + part_size as u64))
                .await
                .unwrap_or(current_pos);
            part_number += 1;
            current_part.store(part_number as u64, std::sync::atomic::Ordering::Relaxed);
            continue;
        }
        // Limit in-flight parts; the permit moves into the spawned task
        // and is released when the task finishes.
        let permit = semaphore
            .clone()
            .acquire_owned()
            .await
            .context("Semaphore closed")?;
        let mut chunk = vec![0u8; part_size];
        use std::io::SeekFrom;
        let offset = (part_number as u64 - 1) * (part_size as u64);
        file.seek(SeekFrom::Start(offset)).await?;
        // Fill the chunk; the final part may read fewer bytes.
        let mut bytes_read = 0;
        while bytes_read < part_size {
            let n = file
                .read(&mut chunk[bytes_read..])
                .await
                .context("Failed to read file chunk")?;
            if n == 0 {
                break;
            }
            bytes_read += n;
        }
        if bytes_read == 0 {
            // End of file reached; release the unused permit.
            drop(permit);
            break;
        }
        chunk.truncate(bytes_read);
        let client = client.clone();
        let bucket = bucket.to_string();
        let key = key.to_string();
        let uid_clone = upload_id.clone();
        let uploaded_bytes = uploaded_bytes.clone();
        let part_hashes_clone = part_hashes.clone();
        let handle = tokio::spawn(async move {
            let _permit = permit;
            // Hash the part locally and send the digest so S3 verifies
            // it server-side (UploadPart ChecksumSHA256).
            let mut hasher = Sha256::new();
            hasher.update(&chunk);
            let hash = hasher.finalize();
            let hash_bytes = hash.to_vec();
            let checksum_sha256 = general_purpose::STANDARD.encode(&hash_bytes);
            let body = aws_sdk_s3::primitives::ByteStream::from(chunk);
            let upload_part_output = client
                .upload_part()
                .bucket(bucket)
                .key(key)
                .upload_id(uid_clone)
                .part_number(part_number)
                .body(body)
                .checksum_sha256(checksum_sha256)
                .send()
                .await
                .map_err(|e| anyhow::anyhow!("Failed to upload part {}: {}", part_number, e))?;
            if let Ok(mut map) = part_hashes_clone.lock() {
                map.insert(part_number, hash_bytes);
            }
            uploaded_bytes.fetch_add(bytes_read as u64, std::sync::atomic::Ordering::Relaxed);
            if let Some(etag) = upload_part_output.e_tag() {
                Ok(CompletedPart::builder()
                    .e_tag(etag)
                    .part_number(part_number)
                    .checksum_sha256(upload_part_output.checksum_sha256().unwrap_or_default())
                    .build())
            } else {
                Err(anyhow::anyhow!("No ETag"))
            }
        });
        handles.push(handle);
        part_number += 1;
        current_part.store(part_number as u64, std::sync::atomic::Ordering::Relaxed);
    }
    // Wait for all dispatched parts; any failure aborts the whole upload
    // (resume remains possible via the stored upload id).
    for handle in handles {
        let part = handle.await??;
        completed_parts.push(part);
    }
    {
        let mut p = lock_async_mutex(&progress).await;
        if let Some(mut info) = p.get(&job_id).cloned() {
            info.details = "Finalizing S3...".to_string();
            p.insert(job_id, info);
        }
    }
    // CompleteMultipartUpload requires ascending part-number order.
    completed_parts.sort_by_key(|a| a.part_number());
    let completed_upload = CompletedMultipartUpload::builder()
        .set_parts(Some(completed_parts.clone()))
        .build();
    let complete_output = client
        .complete_multipart_upload()
        .bucket(bucket)
        .key(&key)
        .upload_id(upload_id)
        .multipart_upload(completed_upload)
        .send()
        .await
        .map_err(|e| anyhow::anyhow!("Failed complete: {}", e))?;
    {
        // Record local vs remote composite checksums for verification.
        let conn = lock_mutex(&conn_mutex)?;
        let remote_checksum = complete_output.checksum_sha256();
        // Local composite: SHA-256 over the concatenated per-part digests
        // (in part order), base64-encoded, suffixed with "-<part count>"
        // — the S3 multipart checksum format.
        let local_checksum = if let Ok(map) = part_hashes.lock() {
            let mut parts: Vec<_> = map.iter().collect();
            parts.sort_by_key(|(k, _)| **k);
            let mut hasher = Sha256::new();
            for (_, bytes) in parts {
                hasher.update(bytes);
            }
            let composite_hash = hasher.finalize();
            let b64 = general_purpose::STANDARD.encode(composite_hash);
            let count = map.len();
            if count > 0 {
                Some(format!("{}-{}", b64, count))
            } else {
                None
            }
        } else {
            None
        };
        let _ =
            db::update_job_checksums(&conn, job_id, local_checksum.as_deref(), remote_checksum);
    }
    {
        // Success: clear the progress entry and the persisted upload id.
        let mut p = lock_async_mutex(&progress).await;
        p.remove(&job_id);
        let conn = lock_mutex(&conn_mutex)?;
        let _ = conn.execute(
            "UPDATE jobs SET s3_upload_id = NULL WHERE id = ?",
            rusqlite::params![job_id],
        );
    }
    Ok(true)
}
/// Chooses a multipart part size (in bytes) for a file of `file_size`
/// bytes.
///
/// Starts from the configured size, shrinks it so `concurrency` parts can
/// span the file, clamps to the S3 minimum of 5 MiB, and finally raises
/// it if necessary so the total part count stays within S3's 10,000-part
/// limit. A `concurrency` of zero is treated as "one part covers the
/// whole file".
#[allow(dead_code)]
pub fn calculate_part_size(
    file_size: u64,
    config_part_size_mb: u64,
    concurrency: usize,
) -> usize {
    const MIN_PART_SIZE: usize = 5 * 1024 * 1024;
    const MAX_PARTS: u64 = 10000;
    let configured = (config_part_size_mb * 1024 * 1024) as usize;
    let per_worker = (file_size as usize)
        .checked_div(concurrency)
        .unwrap_or(file_size as usize);
    let limit_floor = file_size.div_ceil(MAX_PARTS) as usize;
    configured
        .min(per_worker)
        .max(MIN_PART_SIZE)
        .max(limit_floor)
}
/// Returns the base64-encoded SHA-256 digest of `data`, matching the
/// `ChecksumSHA256` format S3 expects for a single part.
#[allow(dead_code)]
pub fn calculate_checksum(data: &[u8]) -> String {
    let digest = Sha256::digest(data);
    general_purpose::STANDARD.encode(digest)
}
/// Combines per-part SHA-256 digests into a composite checksum in S3's
/// multipart format: `base64(sha256(concat(part_hashes)))-<part_count>`.
/// Part order matters, matching how S3 computes the object checksum.
#[allow(dead_code)]
pub fn calculate_composite_checksum(part_hashes: Vec<Vec<u8>>) -> String {
    let mut hasher = Sha256::new();
    part_hashes.iter().for_each(|h| hasher.update(h));
    let encoded = general_purpose::STANDARD.encode(hasher.finalize());
    format!("{}-{}", encoded, part_hashes.len())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- Part-size calculation ---

    #[test]
    fn test_calculate_part_size_respects_minimum() {
        // 10 MB file with a 1 MB configured part: S3's 5 MB floor wins.
        let part_size = Uploader::calculate_part_size(
            10 * 1024 * 1024,
            1,
            4,
        );
        assert_eq!(part_size, 5 * 1024 * 1024, "Should enforce 5MB minimum");
    }

    #[test]
    fn test_calculate_part_size_uses_config_default() {
        // Large file, generous config: configured 128 MB used as-is.
        let part_size = Uploader::calculate_part_size(
            1024 * 1024 * 1024,
            128,
            4,
        );
        assert_eq!(part_size, 128 * 1024 * 1024);
    }

    #[test]
    fn test_calculate_part_size_parallelism_optimization() {
        // Part size shrinks so 8 workers share a 100 MB file, but stays
        // within the [5 MB, configured] bounds.
        let part_size = Uploader::calculate_part_size(
            100 * 1024 * 1024,
            128,
            8,
        );
        assert!(part_size >= 5 * 1024 * 1024);
        assert!(part_size <= 128 * 1024 * 1024);
    }

    #[test]
    fn test_calculate_part_size_respects_10k_limit() {
        // ~5 TB file: part size must grow past the configured 128 MB so
        // the upload fits in S3's 10,000-part maximum.
        let file_size = 5_000_000_000_000u64;
        let part_size = Uploader::calculate_part_size(
            file_size,
            128,
            4,
        );
        let required_min = (file_size / 10000) as usize;
        assert!(
            part_size >= required_min,
            "Part size {} should be at least {} to stay under 10k parts",
            part_size,
            required_min
        );
        let total_parts = (file_size as usize).div_ceil(part_size);
        assert!(
            total_parts <= 10000,
            "Should not exceed 10k parts, got {}",
            total_parts
        );
    }

    #[test]
    fn test_calculate_part_size_edge_case_exact_10k_parts() {
        // File size exactly 10,000 x 5 MB: must not exceed the limit.
        let file_size = 10000 * 5 * 1024 * 1024u64;
        let part_size = Uploader::calculate_part_size(file_size, 128, 4);
        let total_parts = (file_size as usize).div_ceil(part_size);
        assert!(total_parts <= 10000);
    }

    #[test]
    fn test_calculate_part_size_small_file() {
        // Files below 5 MB still report the 5 MB minimum part size.
        let part_size = Uploader::calculate_part_size(
            1024 * 1024,
            128,
            4,
        );
        assert_eq!(part_size, 5 * 1024 * 1024);
    }

    // --- Single-part checksums ---

    #[test]
    fn test_calculate_checksum_empty() {
        // Base64(SHA-256("")) is a well-known constant.
        let checksum = Uploader::calculate_checksum(&[]);
        assert_eq!(checksum, "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=");
    }

    #[test]
    fn test_calculate_checksum_known_value() {
        let data = b"Hello, World!";
        let checksum = Uploader::calculate_checksum(data);
        assert_eq!(checksum, "3/1gIbsr1bCvZ2KQgJ7DpTGR3YHH9wpLKGiKNiGCmG8=");
    }

    #[test]
    fn test_calculate_checksum_different_inputs() {
        let checksum1 = Uploader::calculate_checksum(b"data1");
        let checksum2 = Uploader::calculate_checksum(b"data2");
        assert_ne!(
            checksum1, checksum2,
            "Different inputs should produce different checksums"
        );
    }

    #[test]
    fn test_calculate_checksum_deterministic() {
        let data = b"test data for checksumming";
        let checksum1 = Uploader::calculate_checksum(data);
        let checksum2 = Uploader::calculate_checksum(data);
        assert_eq!(
            checksum1, checksum2,
            "Same input should produce same checksum"
        );
    }

    #[test]
    fn test_calculate_checksum_large_data() {
        // 5 MB of a repeated byte: sanity-check the output shape only.
        let data = vec![0xAB; 5 * 1024 * 1024];
        let checksum = Uploader::calculate_checksum(&data);
        assert!(!checksum.is_empty());
        // Base64 of a 32-byte digest is 44 characters.
        assert!(checksum.len() > 40);
    }

    // --- Composite (multipart) checksums ---

    #[test]
    fn test_calculate_composite_checksum_single_part() {
        let hash1 = vec![0x01, 0x02, 0x03, 0x04];
        let composite = Uploader::calculate_composite_checksum(vec![hash1]);
        assert!(composite.ends_with("-1"), "Should end with part count");
        assert!(composite.len() > 5);
    }

    #[test]
    fn test_calculate_composite_checksum_multiple_parts() {
        let hash1 = vec![0x01; 32];
        let hash2 = vec![0x02; 32];
        let hash3 = vec![0x03; 32];
        let composite = Uploader::calculate_composite_checksum(vec![hash1, hash2, hash3]);
        assert!(
            composite.ends_with("-3"),
            "Should end with part count: {}",
            composite
        );
        let parts: Vec<&str> = composite.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[1], "3");
    }

    #[test]
    fn test_calculate_composite_checksum_format() {
        // Format is "<base64>-<count>" with exactly one dash.
        let hash1 = vec![0xFF; 32];
        let composite = Uploader::calculate_composite_checksum(vec![hash1]);
        assert!(composite.contains('-'));
        let parts: Vec<&str> = composite.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert!(!parts[0].is_empty());
        assert!(parts[1].parse::<usize>().is_ok());
    }

    #[test]
    fn test_calculate_composite_checksum_order_matters() {
        let hash1 = vec![0x01; 32];
        let hash2 = vec![0x02; 32];
        let composite1 = Uploader::calculate_composite_checksum(vec![hash1.clone(), hash2.clone()]);
        let composite2 = Uploader::calculate_composite_checksum(vec![hash2, hash1]);
        assert_ne!(
            composite1, composite2,
            "Part order should affect composite checksum"
        );
    }

    #[test]
    fn test_calculate_composite_checksum_matches_s3_format() {
        let hash1 = vec![0xAB; 32];
        let hash2 = vec![0xCD; 32];
        let composite = Uploader::calculate_composite_checksum(vec![hash1, hash2]);
        let parts: Vec<&str> = composite.split('-').collect();
        assert_eq!(parts.len(), 2, "Should have exactly one dash");
        assert_eq!(parts[1], "2", "Should have correct part count");
        assert_eq!(parts[0].len(), 44, "Base64 of SHA256 should be 44 chars");
    }

    #[test]
    fn test_part_size_for_various_file_sizes() {
        // (file_size, config_mb, concurrency, expected minimum part size)
        let test_cases = vec![
            (1024 * 1024, 128, 4, 5 * 1024 * 1024),
            (100 * 1024 * 1024, 128, 4, 25 * 1024 * 1024),
            (1024 * 1024 * 1024, 128, 4, 128 * 1024 * 1024),
            (10 * 1024 * 1024 * 1024, 128, 4, 128 * 1024 * 1024),
        ];
        for (file_size, config_mb, concurrency, expected_min) in test_cases {
            let part_size = Uploader::calculate_part_size(file_size, config_mb, concurrency);
            assert!(
                part_size >= expected_min,
                "File size {} should have part size >= {}, got {}",
                file_size,
                expected_min,
                part_size
            );
            let total_parts = (file_size as usize).div_ceil(part_size);
            assert!(
                total_parts <= 10000,
                "File size {} created {} parts (exceeds 10k)",
                file_size,
                total_parts
            );
        }
    }

    #[test]
    fn test_checksum_workflow_single_part() {
        // End-to-end: per-part checksum plus a 1-part composite.
        let data = b"Small file content";
        let part_checksum = Uploader::calculate_checksum(data);
        assert!(!part_checksum.is_empty());
        let mut hasher = Sha256::new();
        hasher.update(data);
        let hash_bytes = hasher.finalize().to_vec();
        let composite = Uploader::calculate_composite_checksum(vec![hash_bytes]);
        assert!(composite.ends_with("-1"));
    }

    #[test]
    fn test_checksum_workflow_multipart() {
        // End-to-end: hash three parts and combine into a "-3" composite.
        let part1 = b"Part 1 data";
        let part2 = b"Part 2 data";
        let part3 = b"Part 3 data";
        let mut hashes = Vec::new();
        for part in &[part1, part2, part3] {
            let mut hasher = Sha256::new();
            hasher.update(part);
            hashes.push(hasher.finalize().to_vec());
        }
        let composite = Uploader::calculate_composite_checksum(hashes);
        assert!(composite.ends_with("-3"));
        let parts: Vec<&str> = composite.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[1], "3");
    }
}