use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::{Result, anyhow};
use aws_sdk_s3::types::RequestPayer;
use leaky_bucket::RateLimiter;
use tracing::{info, trace};
use s3util_rs::Config;
use s3util_rs::storage::Storage;
use s3util_rs::storage::StorageFactory;
use s3util_rs::storage::local::LocalStorageFactory;
use s3util_rs::storage::s3::S3StorageFactory;
use s3util_rs::transfer::{TransferDirection, TransferOutcome, detect_direction};
use s3util_rs::types::StoragePath;
use s3util_rs::types::token::{PipelineCancellationToken, create_pipeline_cancellation_token};
pub mod cp;
pub mod create_bucket;
pub mod ctrl_c_handler;
pub mod delete_bucket;
pub mod delete_bucket_cors;
pub mod delete_bucket_encryption;
pub mod delete_bucket_lifecycle_configuration;
pub mod delete_bucket_policy;
pub mod delete_bucket_tagging;
pub mod delete_bucket_website;
pub mod delete_object_tagging;
pub mod delete_public_access_block;
pub mod get_bucket_cors;
pub mod get_bucket_encryption;
pub mod get_bucket_lifecycle_configuration;
pub mod get_bucket_logging;
pub mod get_bucket_notification_configuration;
pub mod get_bucket_policy;
pub mod get_bucket_tagging;
pub mod get_bucket_versioning;
pub mod get_bucket_website;
pub mod get_object_tagging;
pub mod get_public_access_block;
pub mod head_bucket;
pub mod head_object;
pub mod indicator;
pub mod mv;
pub mod put_bucket_cors;
pub mod put_bucket_encryption;
pub mod put_bucket_lifecycle_configuration;
pub mod put_bucket_logging;
pub mod put_bucket_notification_configuration;
pub mod put_bucket_policy;
pub mod put_bucket_tagging;
pub mod put_bucket_versioning;
pub mod put_bucket_website;
pub mod put_object_tagging;
pub mod put_public_access_block;
pub mod rm;
pub mod tagging;
pub mod ui_config;
pub use cp::run_cp;
pub use create_bucket::run_create_bucket;
pub use delete_bucket::run_delete_bucket;
pub use delete_bucket_cors::run_delete_bucket_cors;
pub use delete_bucket_encryption::run_delete_bucket_encryption;
pub use delete_bucket_lifecycle_configuration::run_delete_bucket_lifecycle_configuration;
pub use delete_bucket_policy::run_delete_bucket_policy;
pub use delete_bucket_tagging::run_delete_bucket_tagging;
pub use delete_bucket_website::run_delete_bucket_website;
pub use delete_object_tagging::run_delete_object_tagging;
pub use delete_public_access_block::run_delete_public_access_block;
pub use get_bucket_cors::run_get_bucket_cors;
pub use get_bucket_encryption::run_get_bucket_encryption;
pub use get_bucket_lifecycle_configuration::run_get_bucket_lifecycle_configuration;
pub use get_bucket_logging::run_get_bucket_logging;
pub use get_bucket_notification_configuration::run_get_bucket_notification_configuration;
pub use get_bucket_policy::run_get_bucket_policy;
pub use get_bucket_tagging::run_get_bucket_tagging;
pub use get_bucket_versioning::run_get_bucket_versioning;
pub use get_bucket_website::run_get_bucket_website;
pub use get_object_tagging::run_get_object_tagging;
pub use get_public_access_block::run_get_public_access_block;
pub use head_bucket::run_head_bucket;
pub use head_object::run_head_object;
pub use mv::run_mv;
pub use put_bucket_cors::run_put_bucket_cors;
pub use put_bucket_encryption::run_put_bucket_encryption;
pub use put_bucket_lifecycle_configuration::run_put_bucket_lifecycle_configuration;
pub use put_bucket_logging::run_put_bucket_logging;
pub use put_bucket_notification_configuration::run_put_bucket_notification_configuration;
pub use put_bucket_policy::run_put_bucket_policy;
pub use put_bucket_tagging::run_put_bucket_tagging;
pub use put_bucket_versioning::run_put_bucket_versioning;
pub use put_bucket_website::run_put_bucket_website;
pub use put_object_tagging::run_put_object_tagging;
pub use put_public_access_block::run_put_public_access_block;
pub use rm::run_rm;
const REFILL_PER_INTERVAL_DIVIDER: usize = 10;
fn build_rate_limiter(config: &Config) -> Option<Arc<RateLimiter>> {
config.rate_limit_bandwidth.map(|bandwidth| {
let refill = bandwidth as usize / REFILL_PER_INTERVAL_DIVIDER;
Arc::new(
RateLimiter::builder()
.max(bandwidth as usize)
.initial(bandwidth as usize)
.refill(refill)
.fair(true)
.build(),
)
})
}
#[derive(Debug)]
pub enum ExitStatus {
Success,
Warning,
NotFound,
Cancelled,
}
impl ExitStatus {
pub fn code(&self) -> i32 {
match self {
ExitStatus::Success => EXIT_CODE_SUCCESS,
ExitStatus::Warning => EXIT_CODE_WARNING,
ExitStatus::NotFound => EXIT_CODE_NOT_FOUND,
ExitStatus::Cancelled => EXIT_CODE_CANCELLED,
}
}
}
pub const EXIT_CODE_SUCCESS: i32 = 0;
pub const EXIT_CODE_ERROR: i32 = 1;
pub const EXIT_CODE_WARNING: i32 = 3;
pub const EXIT_CODE_NOT_FOUND: i32 = 4;
pub const EXIT_CODE_CANCELLED: i32 = 130;
pub struct CopyPhase {
pub transfer_result: Result<TransferOutcome>,
pub source_storage: Storage,
pub source_key: String,
pub cancellation_token: PipelineCancellationToken,
pub cancelled: bool,
pub has_warning: bool,
}
pub async fn run_copy_phase(config: Config) -> Result<CopyPhase> {
let cancellation_token = create_pipeline_cancellation_token();
ctrl_c_handler::spawn_ctrl_c_handler(cancellation_token.clone());
let (stats_sender, stats_receiver) = async_channel::unbounded();
let (source_str, target_str) = get_path_strings(&config.source, &config.target);
let direction = detect_direction(&source_str, &target_str)?;
trace!(direction = ?direction, "detected transfer direction");
check_local_source_not_directory(&config.source, &direction)?;
let (source_key, target_key) = extract_keys(&config)?;
let resolved_target_display = format_target_path(&config.target, &target_key);
let show_progress = ui_config::is_progress_indicator_needed(&config);
let show_result = ui_config::is_show_result_needed(&config);
let log_sync_summary = config.tracing_config.is_some();
let indicator_handle = indicator::show_indicator(
stats_receiver,
show_progress,
show_result,
log_sync_summary,
resolved_target_display,
source_key.clone(),
target_key.clone(),
config.dry_run,
);
let has_warning = Arc::new(AtomicBool::new(false));
let rate_limit_bandwidth = build_rate_limiter(&config);
let (transfer_result, source_for_caller) = match direction {
TransferDirection::LocalToS3 => {
let target_request_payer = if config.target_request_payer {
Some(RequestPayer::Requester)
} else {
None
};
let source = LocalStorageFactory::create(
config.clone(),
empty_local_storage_path(),
cancellation_token.clone(),
stats_sender.clone(),
None,
None,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let source_for_caller = dyn_clone::clone_box(&*source);
let target = S3StorageFactory::create(
config.clone(),
empty_s3_storage_path(&config.target),
cancellation_token.clone(),
stats_sender.clone(),
config.target_client_config.clone(),
target_request_payer,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let result = if config.dry_run {
info!(
source = %source_str,
target = %target_str,
"[dry-run] would copy object."
);
Ok(s3util_rs::transfer::TransferOutcome::default())
} else {
s3util_rs::transfer::local_to_s3::transfer(
&config,
source,
target,
&source_key,
&target_key,
cancellation_token.clone(),
stats_sender.clone(),
)
.await
};
(result, source_for_caller)
}
TransferDirection::S3ToLocal => {
let source_request_payer = if config.source_request_payer {
Some(RequestPayer::Requester)
} else {
None
};
let source = S3StorageFactory::create(
config.clone(),
empty_s3_storage_path(&config.source),
cancellation_token.clone(),
stats_sender.clone(),
config.source_client_config.clone(),
source_request_payer,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let source_for_caller = dyn_clone::clone_box(&*source);
let target = LocalStorageFactory::create(
config.clone(),
empty_local_storage_path(),
cancellation_token.clone(),
stats_sender.clone(),
None,
None,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let result = if config.dry_run {
info!(
source = %source_str,
target = %target_str,
"[dry-run] would copy object."
);
Ok(s3util_rs::transfer::TransferOutcome::default())
} else {
s3util_rs::transfer::s3_to_local::transfer(
&config,
source,
target,
&source_key,
&target_key,
cancellation_token.clone(),
stats_sender.clone(),
)
.await
};
(result, source_for_caller)
}
TransferDirection::S3ToS3 => {
let source_request_payer = if config.source_request_payer {
Some(RequestPayer::Requester)
} else {
None
};
let target_request_payer = if config.target_request_payer {
Some(RequestPayer::Requester)
} else {
None
};
let source = S3StorageFactory::create(
config.clone(),
empty_s3_storage_path(&config.source),
cancellation_token.clone(),
stats_sender.clone(),
config.source_client_config.clone(),
source_request_payer,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let source_for_caller = dyn_clone::clone_box(&*source);
let target = S3StorageFactory::create(
config.clone(),
empty_s3_storage_path(&config.target),
cancellation_token.clone(),
stats_sender.clone(),
config.target_client_config.clone(),
target_request_payer,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let result = if config.dry_run {
info!(
source = %source_str,
target = %target_str,
"[dry-run] would copy object."
);
Ok(s3util_rs::transfer::TransferOutcome::default())
} else {
s3util_rs::transfer::s3_to_s3::transfer(
&config,
source,
target,
&source_key,
&target_key,
cancellation_token.clone(),
stats_sender.clone(),
)
.await
};
(result, source_for_caller)
}
TransferDirection::StdioToS3 => {
let target_request_payer = if config.target_request_payer {
Some(RequestPayer::Requester)
} else {
None
};
let target = S3StorageFactory::create(
config.clone(),
empty_s3_storage_path(&config.target),
cancellation_token.clone(),
stats_sender.clone(),
config.target_client_config.clone(),
target_request_payer,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let source_for_caller = LocalStorageFactory::create(
config.clone(),
empty_local_storage_path(),
cancellation_token.clone(),
stats_sender.clone(),
None,
None,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let result = if config.dry_run {
info!(
source = %source_str,
target = %target_str,
"[dry-run] would copy object."
);
Ok(s3util_rs::transfer::TransferOutcome::default())
} else {
s3util_rs::transfer::stdio_to_s3::transfer(
&config,
target,
&target_key,
tokio::io::stdin(),
cancellation_token.clone(),
stats_sender.clone(),
)
.await
};
(result, source_for_caller)
}
TransferDirection::S3ToStdio => {
let source_request_payer = if config.source_request_payer {
Some(RequestPayer::Requester)
} else {
None
};
let source = S3StorageFactory::create(
config.clone(),
empty_s3_storage_path(&config.source),
cancellation_token.clone(),
stats_sender.clone(),
config.source_client_config.clone(),
source_request_payer,
rate_limit_bandwidth.clone(),
has_warning.clone(),
None,
)
.await;
let source_for_caller = dyn_clone::clone_box(&*source);
let result = if config.dry_run {
info!(
source = %source_str,
target = %target_str,
"[dry-run] would copy object."
);
Ok(s3util_rs::transfer::TransferOutcome::default())
} else {
s3util_rs::transfer::s3_to_stdio::transfer(
&config,
source,
&source_key,
tokio::io::stdout(),
cancellation_token.clone(),
stats_sender.clone(),
)
.await
};
(result, source_for_caller)
}
};
if transfer_result.is_err() {
let _ = stats_sender
.send(s3util_rs::types::SyncStatistics::SyncError { key: String::new() })
.await;
}
stats_sender.close();
let _ = indicator_handle.await;
let cancelled = cancellation_token.is_cancelled();
let has_warning = has_warning.load(std::sync::atomic::Ordering::SeqCst);
Ok(CopyPhase {
transfer_result,
source_storage: source_for_caller,
source_key,
cancellation_token,
cancelled,
has_warning,
})
}
fn get_path_strings(source: &StoragePath, target: &StoragePath) -> (String, String) {
let source_str = match source {
StoragePath::S3 { bucket, prefix } => {
if prefix.is_empty() {
format!("s3://{}", bucket)
} else {
format!("s3://{}/{}", bucket, prefix)
}
}
StoragePath::Local(path) => path.to_string_lossy().to_string(),
StoragePath::Stdio => "-".to_string(),
};
let target_str = match target {
StoragePath::S3 { bucket, prefix } => {
if prefix.is_empty() {
format!("s3://{}", bucket)
} else {
format!("s3://{}/{}", bucket, prefix)
}
}
StoragePath::Local(path) => path.to_string_lossy().to_string(),
StoragePath::Stdio => "-".to_string(),
};
(source_str, target_str)
}
fn extract_keys(config: &Config) -> Result<(String, String)> {
let source_key = match &config.source {
StoragePath::S3 { prefix, .. } => {
if prefix.is_empty() {
return Err(anyhow!("source S3 key is required (e.g. s3://bucket/key)"));
}
prefix.clone()
}
StoragePath::Local(path) => path.to_string_lossy().to_string(),
StoragePath::Stdio => String::new(),
};
let source_basename = std::path::Path::new(&source_key)
.file_name()
.map(|f| f.to_string_lossy().to_string())
.unwrap_or(source_key.clone());
let target_key = match &config.target {
StoragePath::S3 { prefix, .. } => {
if prefix.is_empty() || prefix.ends_with('/') {
if source_basename.is_empty() {
return Err(anyhow!(
"target S3 key is required when source is stdin (e.g. s3://bucket/key)"
));
}
format!("{prefix}{source_basename}")
} else {
prefix.clone()
}
}
StoragePath::Local(path) => {
let p = path.clone();
if p.is_dir() || p.to_string_lossy().ends_with(std::path::MAIN_SEPARATOR) {
p.join(&source_basename).to_string_lossy().to_string()
} else {
p.to_string_lossy().to_string()
}
}
StoragePath::Stdio => String::new(),
};
Ok((source_key, target_key))
}
fn format_target_path(target: &StoragePath, target_key: &str) -> String {
match target {
StoragePath::S3 { bucket, .. } => format!("s3://{bucket}/{target_key}"),
StoragePath::Local(_) => target_key.to_string(),
StoragePath::Stdio => "-".to_string(),
}
}
fn check_local_source_not_directory(
source: &StoragePath,
direction: &TransferDirection,
) -> Result<()> {
if !matches!(direction, TransferDirection::LocalToS3) {
return Ok(());
}
if let StoragePath::Local(path) = source
&& path.is_dir()
{
return Err(anyhow!(
"source is a directory: {}. cp command copies a single file.",
path.display()
));
}
Ok(())
}
fn empty_local_storage_path() -> StoragePath {
StoragePath::Local(".".into())
}
fn empty_s3_storage_path(original: &StoragePath) -> StoragePath {
match original {
StoragePath::S3 { bucket, .. } => StoragePath::S3 {
bucket: bucket.clone(),
prefix: String::new(),
},
_ => unreachable!("expected S3 storage path"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use s3util_rs::config::args::{Commands, parse_from_args};
use std::path::PathBuf;
fn build_config(args: Vec<&str>) -> Config {
let cli = parse_from_args(args).unwrap();
let Commands::Cp(cp_args) = cli.command else {
panic!("expected Cp variant");
};
Config::try_from(cp_args).unwrap()
}
#[test]
fn exit_status_codes() {
assert_eq!(ExitStatus::Success.code(), EXIT_CODE_SUCCESS);
assert_eq!(ExitStatus::Warning.code(), EXIT_CODE_WARNING);
assert_eq!(ExitStatus::NotFound.code(), EXIT_CODE_NOT_FOUND);
assert_eq!(ExitStatus::Cancelled.code(), EXIT_CODE_CANCELLED);
assert_eq!(EXIT_CODE_SUCCESS, 0);
assert_eq!(EXIT_CODE_ERROR, 1);
assert_eq!(EXIT_CODE_WARNING, 3);
assert_eq!(EXIT_CODE_NOT_FOUND, 4);
assert_eq!(EXIT_CODE_CANCELLED, 130);
}
#[test]
fn get_path_strings_formats_each_storage_kind() {
let s3_with_prefix = StoragePath::S3 {
bucket: "b".to_string(),
prefix: "k/v".to_string(),
};
let s3_no_prefix = StoragePath::S3 {
bucket: "b".to_string(),
prefix: String::new(),
};
let local = StoragePath::Local(PathBuf::from("/tmp/x"));
let stdio = StoragePath::Stdio;
let (src, tgt) = get_path_strings(&s3_with_prefix, &local);
assert_eq!(src, "s3://b/k/v");
assert_eq!(tgt, "/tmp/x");
let (src, tgt) = get_path_strings(&s3_no_prefix, &stdio);
assert_eq!(src, "s3://b");
assert_eq!(tgt, "-");
let (src, tgt) = get_path_strings(&stdio, &s3_with_prefix);
assert_eq!(src, "-");
assert_eq!(tgt, "s3://b/k/v");
}
#[test]
fn format_target_path_for_each_storage_kind() {
let s3 = StoragePath::S3 {
bucket: "mybucket".to_string(),
prefix: String::new(),
};
assert_eq!(format_target_path(&s3, "k/v.dat"), "s3://mybucket/k/v.dat");
let local = StoragePath::Local(PathBuf::from("/x"));
assert_eq!(format_target_path(&local, "ignored"), "ignored");
assert_eq!(format_target_path(&StoragePath::Stdio, "ignored"), "-");
}
#[test]
fn empty_local_storage_path_is_dot() {
match empty_local_storage_path() {
StoragePath::Local(p) => assert_eq!(p, PathBuf::from(".")),
_ => panic!("expected Local"),
}
}
#[test]
fn empty_s3_storage_path_clears_prefix_keeps_bucket() {
let original = StoragePath::S3 {
bucket: "mybucket".to_string(),
prefix: "some/key".to_string(),
};
match empty_s3_storage_path(&original) {
StoragePath::S3 { bucket, prefix } => {
assert_eq!(bucket, "mybucket");
assert_eq!(prefix, "");
}
_ => panic!("expected S3"),
}
}
#[test]
fn build_rate_limiter_returns_none_when_unset() {
let config = build_config(vec!["s3util", "cp", "/tmp/a", "s3://b/k"]);
assert!(config.rate_limit_bandwidth.is_none());
assert!(build_rate_limiter(&config).is_none());
}
#[test]
fn build_rate_limiter_returns_some_when_set() {
let config = build_config(vec![
"s3util",
"cp",
"--rate-limit-bandwidth",
"10MiB",
"/tmp/a",
"s3://b/k",
]);
assert!(config.rate_limit_bandwidth.is_some());
assert!(build_rate_limiter(&config).is_some());
}
#[test]
fn extract_keys_local_to_s3_object_target() {
let config = build_config(vec!["s3util", "cp", "/tmp/source.dat", "s3://b/key.dat"]);
let (src, tgt) = extract_keys(&config).unwrap();
assert_eq!(src, "/tmp/source.dat");
assert_eq!(tgt, "key.dat");
}
#[test]
fn extract_keys_local_to_s3_bucket_only_uses_basename() {
let config = build_config(vec!["s3util", "cp", "/tmp/source.dat", "s3://b"]);
let (_, tgt) = extract_keys(&config).unwrap();
assert_eq!(tgt, "source.dat");
}
#[test]
fn extract_keys_local_to_s3_prefix_with_slash_appends_basename() {
let config = build_config(vec!["s3util", "cp", "/tmp/source.dat", "s3://b/dir/"]);
let (_, tgt) = extract_keys(&config).unwrap();
assert_eq!(tgt, "dir/source.dat");
}
#[test]
fn extract_keys_s3_to_local_with_no_source_key_errors() {
let dir = tempfile::tempdir().unwrap();
let target = dir.path().join("dst").to_string_lossy().to_string();
let config = build_config(vec!["s3util", "cp", "s3://b", &target]);
let err = extract_keys(&config).unwrap_err();
assert!(err.to_string().contains("source S3 key is required"));
}
#[test]
fn extract_keys_stdio_target_yields_empty_target_key() {
let config = build_config(vec!["s3util", "cp", "s3://b/key", "-"]);
let (src, tgt) = extract_keys(&config).unwrap();
assert_eq!(src, "key");
assert_eq!(tgt, "");
}
#[test]
fn extract_keys_stdio_source_yields_empty_source_key() {
let config = build_config(vec!["s3util", "cp", "-", "s3://b/key"]);
let (src, tgt) = extract_keys(&config).unwrap();
assert_eq!(src, "");
assert_eq!(tgt, "key");
}
#[test]
fn extract_keys_stdio_to_s3_bucket_only_errors() {
let config = build_config(vec!["s3util", "cp", "-", "s3://b"]);
let err = extract_keys(&config).unwrap_err();
assert!(err.to_string().contains("target S3 key is required"));
}
#[test]
fn extract_keys_stdio_to_s3_prefix_with_slash_errors() {
let config = build_config(vec!["s3util", "cp", "-", "s3://b/dir/"]);
let err = extract_keys(&config).unwrap_err();
assert!(err.to_string().contains("target S3 key is required"));
}
#[test]
fn check_local_source_not_directory_rejects_directory() {
let tmp = tempfile::tempdir().unwrap();
let source = StoragePath::Local(tmp.path().to_path_buf());
let err =
check_local_source_not_directory(&source, &TransferDirection::LocalToS3).unwrap_err();
let msg = err.to_string();
assert!(msg.contains("source is a directory"), "message: {msg}");
}
#[test]
fn check_local_source_not_directory_allows_file() {
let tmp = tempfile::NamedTempFile::new().unwrap();
let source = StoragePath::Local(tmp.path().to_path_buf());
check_local_source_not_directory(&source, &TransferDirection::LocalToS3).unwrap();
}
#[test]
fn check_local_source_not_directory_allows_nonexistent_path() {
let source = StoragePath::Local(PathBuf::from("/nonexistent/path/for/test"));
check_local_source_not_directory(&source, &TransferDirection::LocalToS3).unwrap();
}
#[test]
fn check_local_source_not_directory_skips_non_local_to_s3_direction() {
let tmp = tempfile::tempdir().unwrap();
let source = StoragePath::Local(tmp.path().to_path_buf());
for direction in [
TransferDirection::S3ToLocal,
TransferDirection::S3ToS3,
TransferDirection::StdioToS3,
TransferDirection::S3ToStdio,
] {
check_local_source_not_directory(&source, &direction).unwrap();
}
}
#[test]
fn extract_keys_s3_to_existing_local_directory_appends_basename() {
let tmp = tempfile::tempdir().unwrap();
let target_arg = tmp.path().to_string_lossy().to_string();
let config = build_config(vec![
"s3util",
"cp",
"s3://b/remote/file.dat",
target_arg.as_str(),
]);
let (_, tgt) = extract_keys(&config).unwrap();
let expected = tmp.path().join("file.dat").to_string_lossy().to_string();
assert_eq!(tgt, expected);
}
#[test]
fn extract_keys_s3_to_local_path_with_trailing_separator_appends_basename() {
let tmp = tempfile::tempdir().unwrap();
let sep = std::path::MAIN_SEPARATOR;
let target_arg = format!("{}{sep}", tmp.path().to_string_lossy());
let config = build_config(vec![
"s3util",
"cp",
"s3://b/remote/object.bin",
target_arg.as_str(),
]);
let (_, tgt) = extract_keys(&config).unwrap();
assert!(tgt.ends_with("object.bin"), "target was: {tgt}");
}
}