use crate::config::args::value_parser::url;
use crate::config::{CLITimeoutConfig, ClientConfig, RetryConfig, TracingConfig};
use crate::types::{AccessKeys, ClientConfigLocation, S3Credentials};
use aws_sdk_s3::types::RequestPayer;
use aws_smithy_types::checksum_config::RequestChecksumCalculation;
use clap::Parser;
use clap::builder::NonEmptyStringValueParser;
use clap_verbosity_flag::{Verbosity, WarnLevel};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::Semaphore;
use super::common::{
DEFAULT_ACCELERATE, DEFAULT_AWS_MAX_ATTEMPTS, DEFAULT_AWS_SDK_TRACING,
DEFAULT_DISABLE_COLOR_TRACING, DEFAULT_DISABLE_STALLED_STREAM_PROTECTION,
DEFAULT_FORCE_PATH_STYLE, DEFAULT_INITIAL_BACKOFF_MILLISECONDS, DEFAULT_JSON_TRACING,
DEFAULT_REQUEST_PAYER, DEFAULT_SPAN_EVENTS_TRACING,
};
/// Default for `--target-no-sign-request`: off, so `build_client_config` only selects
/// `S3Credentials::NoSignRequest` when the flag (or its environment variable) is given.
const DEFAULT_TARGET_NO_SIGN_REQUEST: bool = false;
// CLI arguments for the target S3 client, tracing/logging, retries and timeouts, designed to be
// embedded in a command's parser via `#[command(flatten)]` (as the tests below do).
//
// NOTE: fields are documented with plain `//` comments on purpose. clap-derive turns `///`
// doc comments into `--help` text, which would change the program's user-visible output.
// Every arg marked `env` can also be supplied through an environment variable named after the
// field in SCREAMING_SNAKE_CASE (e.g. `target_profile` -> `TARGET_PROFILE`).
#[derive(Parser, Clone, Debug)]
pub struct CommonClientArgs {
// --- Tracing/Logging: output-format switches copied into `TracingConfig`. ---
#[arg(long, env, default_value_t = DEFAULT_JSON_TRACING, help_heading = "Tracing/Logging")]
pub json_tracing: bool,
#[arg(long, env, default_value_t = DEFAULT_AWS_SDK_TRACING, help_heading = "Tracing/Logging")]
pub aws_sdk_tracing: bool,
#[arg(long, env, default_value_t = DEFAULT_SPAN_EVENTS_TRACING, help_heading = "Tracing/Logging")]
pub span_events_tracing: bool,
#[arg(long, env, default_value_t = DEFAULT_DISABLE_COLOR_TRACING, help_heading = "Tracing/Logging")]
pub disable_color_tracing: bool,
// `-v`/`-q` counters; the default level is Warn. `log_level()` returns `None` once fully
// silenced, which `build_tracing_config` maps to "no tracing config".
#[command(flatten)]
pub verbosity: Verbosity<WarnLevel>,
// --- AWS Configuration: optional overrides for the AWS config / shared-credentials files. ---
#[arg(long, env, value_name = "FILE", help_heading = "AWS Configuration")]
pub aws_config_file: Option<PathBuf>,
#[arg(long, env, value_name = "FILE", help_heading = "AWS Configuration")]
pub aws_shared_credentials_file: Option<PathBuf>,
// Named profile; mutually exclusive with every explicit-key flag.
#[arg(long, env, conflicts_with_all = ["target_access_key", "target_secret_access_key", "target_session_token"], help_heading = "AWS Configuration")]
pub target_profile: Option<String>,
// Explicit keys: access key and secret must be given together (each `requires` the other);
// `build_client_config` relies on this pairing when it calls `expect` on the secret.
#[arg(long, env, conflicts_with_all = ["target_profile"], requires = "target_secret_access_key", help_heading = "AWS Configuration")]
pub target_access_key: Option<String>,
#[arg(long, env, conflicts_with_all = ["target_profile"], requires = "target_access_key", help_heading = "AWS Configuration")]
pub target_secret_access_key: Option<String>,
// Optional session token; only meaningful alongside an explicit access key.
#[arg(long, env, conflicts_with_all = ["target_profile"], requires = "target_access_key", help_heading = "AWS Configuration")]
pub target_session_token: Option<String>,
// Region; an empty string is rejected at parse time by `NonEmptyStringValueParser`.
#[arg(long, env, value_parser = NonEmptyStringValueParser::new(), help_heading = "AWS Configuration")]
pub target_region: Option<String>,
// Custom endpoint; validated by the project's `url::check_scheme` parser (presumably checks
// for an http/https scheme -- confirm in `value_parser::url`).
#[arg(long, env, value_parser = url::check_scheme, help_heading = "AWS Configuration")]
pub target_endpoint_url: Option<String>,
#[arg(long, env, default_value_t = DEFAULT_FORCE_PATH_STYLE, help_heading = "AWS Configuration")]
pub target_force_path_style: bool,
#[arg(long, env, default_value_t = DEFAULT_ACCELERATE, help_heading = "AWS Configuration")]
pub target_accelerate: bool,
// When set, `build_client_config` sends `RequestPayer::Requester`.
#[arg(long, env, default_value_t = DEFAULT_REQUEST_PAYER, help_heading = "AWS Configuration")]
pub target_request_payer: bool,
// Anonymous access: conflicts with every credential source and with request-payer.
#[arg(
long,
env,
default_value_t = DEFAULT_TARGET_NO_SIGN_REQUEST,
conflicts_with_all = [
"target_profile",
"target_access_key",
"target_secret_access_key",
"target_session_token",
"target_request_payer",
],
help_heading = "AWS Configuration"
)]
pub target_no_sign_request: bool,
#[arg(long, env, default_value_t = DEFAULT_DISABLE_STALLED_STREAM_PROTECTION, help_heading = "AWS Configuration")]
pub disable_stalled_stream_protection: bool,
// --- Retry Options: copied verbatim into `RetryConfig`. ---
#[arg(long, env, default_value_t = DEFAULT_AWS_MAX_ATTEMPTS, help_heading = "Retry Options")]
pub aws_max_attempts: u32,
#[arg(long, env, default_value_t = DEFAULT_INITIAL_BACKOFF_MILLISECONDS, help_heading = "Retry Options")]
pub initial_backoff_milliseconds: u64,
// --- Timeout Options: all optional; `None` is passed through to `CLITimeoutConfig`
// unchanged (presumably meaning "use the SDK default" -- confirm where it is consumed). ---
#[arg(long, env, help_heading = "Timeout Options")]
pub operation_timeout_milliseconds: Option<u64>,
#[arg(long, env, help_heading = "Timeout Options")]
pub operation_attempt_timeout_milliseconds: Option<u64>,
#[arg(long, env, help_heading = "Timeout Options")]
pub connect_timeout_milliseconds: Option<u64>,
#[arg(long, env, help_heading = "Timeout Options")]
pub read_timeout_milliseconds: Option<u64>,
// Shell to emit a completion script for, parsed via `Shell::from_str`.
// NOTE(review): the help text reads "a auto completions script"; consider
// "an auto-completion script" (user-visible string, left unchanged here).
#[arg(
long,
env,
value_name = "SHELL",
value_parser = clap_complete::shells::Shell::from_str,
help_heading = "Advanced",
long_help = r#"Generate a auto completions script.
Valid choices: bash, fish, zsh, powershell, elvish."#
)]
pub auto_complete_shell: Option<clap_complete::shells::Shell>,
}
impl CommonClientArgs {
    /// Builds the S3 [`ClientConfig`] for the target side from the parsed arguments.
    ///
    /// The credential source is picked with this precedence:
    /// 1. `--target-no-sign-request` → [`S3Credentials::NoSignRequest`]
    /// 2. `--target-profile` → [`S3Credentials::Profile`]
    /// 3. `--target-access-key` (with its secret and optional session token) →
    ///    [`S3Credentials::Credentials`]
    /// 4. none of the above → [`S3Credentials::FromEnvironment`]
    ///
    /// Request checksum calculation is fixed to `WhenRequired`. The parallel-upload semaphore
    /// is created with a single permit.
    ///
    /// # Panics
    ///
    /// Panics if `target_access_key` is set while `target_secret_access_key` is not. The clap
    /// `requires` constraint rejects that combination during parsing, so in practice this
    /// indicates a hand-built value.
    pub fn build_client_config(&self) -> ClientConfig {
        let credential = match (
            self.target_no_sign_request,
            self.target_profile.as_ref(),
            self.target_access_key.as_ref(),
        ) {
            (true, _, _) => S3Credentials::NoSignRequest,
            (false, Some(profile_name), _) => S3Credentials::Profile(profile_name.clone()),
            (false, None, Some(key_id)) => S3Credentials::Credentials {
                access_keys: AccessKeys {
                    access_key: key_id.clone(),
                    secret_access_key: self.target_secret_access_key.clone().expect(
                        "clap requires --target-secret-access-key alongside --target-access-key",
                    ),
                    session_token: self.target_session_token.clone(),
                },
            },
            (false, None, None) => S3Credentials::FromEnvironment,
        };

        // The CLI can only express "requester pays"; otherwise the header is omitted.
        let request_payer = self.target_request_payer.then_some(RequestPayer::Requester);

        let client_config_location = ClientConfigLocation {
            aws_config_file: self.aws_config_file.clone(),
            aws_shared_credentials_file: self.aws_shared_credentials_file.clone(),
        };
        let retry_config = RetryConfig {
            aws_max_attempts: self.aws_max_attempts,
            initial_backoff_milliseconds: self.initial_backoff_milliseconds,
        };
        let cli_timeout_config = CLITimeoutConfig {
            operation_timeout_milliseconds: self.operation_timeout_milliseconds,
            operation_attempt_timeout_milliseconds: self.operation_attempt_timeout_milliseconds,
            connect_timeout_milliseconds: self.connect_timeout_milliseconds,
            read_timeout_milliseconds: self.read_timeout_milliseconds,
        };

        ClientConfig {
            client_config_location,
            credential,
            region: self.target_region.clone(),
            endpoint_url: self.target_endpoint_url.clone(),
            force_path_style: self.target_force_path_style,
            accelerate: self.target_accelerate,
            request_payer,
            retry_config,
            cli_timeout_config,
            disable_stalled_stream_protection: self.disable_stalled_stream_protection,
            request_checksum_calculation: RequestChecksumCalculation::WhenRequired,
            parallel_upload_semaphore: Arc::new(Semaphore::new(1)),
        }
    }

    /// Builds the tracing configuration implied by the `-v`/`-q` counters and the
    /// tracing-format flags.
    ///
    /// Returns `None` when the verbosity flags silence logging entirely.
    pub fn build_tracing_config(&self) -> Option<TracingConfig> {
        let level = self.verbosity.log_level()?;
        Some(self.tracing_config_at(level))
    }

    /// Dry-run-aware variant of [`Self::build_tracing_config`].
    ///
    /// With `dry_run == false` this defers to [`Self::build_tracing_config`]. With
    /// `dry_run == true` it always returns `Some`. The level is raised to at least
    /// [`log::Level::Info`], even when logging was silenced. More verbose requests
    /// (`Debug`, `Trace`) are kept as-is.
    pub fn build_tracing_config_dry_run(&self, dry_run: bool) -> Option<TracingConfig> {
        if dry_run {
            let level = match self.verbosity.log_level() {
                // `log::Level` orders Error < Warn < Info < Debug < Trace, so `max` only ever
                // lifts quieter levels up to Info and never lowers Debug/Trace.
                Some(requested) => requested.max(log::Level::Info),
                None => log::Level::Info,
            };
            Some(self.tracing_config_at(level))
        } else {
            self.build_tracing_config()
        }
    }

    /// Assembles a [`TracingConfig`] at `tracing_level`, copying the output-format flags.
    fn tracing_config_at(&self, tracing_level: log::Level) -> TracingConfig {
        TracingConfig {
            tracing_level,
            json_tracing: self.json_tracing,
            aws_sdk_tracing: self.aws_sdk_tracing,
            span_events_tracing: self.span_events_tracing,
            disable_color_tracing: self.disable_color_tracing,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Parser-level and builder-level tests for `CommonClientArgs`.
    //!
    //! These tests parse real argv slices through a minimal wrapper CLI, so they exercise the
    //! clap constraints (`requires`, `conflicts_with_all`, value parsers) as well as the
    //! `build_*` helpers.

    use super::*;
    use clap::Parser;

    /// Minimal host parser that flattens `CommonClientArgs`, as a real command would.
    #[derive(Parser, Debug)]
    struct TestCli {
        #[command(flatten)]
        common: CommonClientArgs,
    }

    #[test]
    fn parses_with_no_flags() {
        let cli = TestCli::try_parse_from(["test"]).unwrap();
        assert!(cli.common.target_region.is_none());
        assert_eq!(cli.common.aws_max_attempts, DEFAULT_AWS_MAX_ATTEMPTS);
        assert!(!cli.common.json_tracing);
    }

    #[test]
    fn target_access_key_requires_secret() {
        let res = TestCli::try_parse_from(["test", "--target-access-key", "AKIA"]);
        assert!(res.is_err(), "must require --target-secret-access-key");
    }

    #[test]
    fn target_session_token_requires_access_key() {
        let res = TestCli::try_parse_from(["test", "--target-session-token", "token"]);
        assert!(
            res.is_err(),
            "session token without --target-access-key must be rejected"
        );
    }

    #[test]
    fn target_access_key_conflicts_with_profile() {
        let res = TestCli::try_parse_from([
            "test",
            "--target-profile",
            "prod",
            "--target-access-key",
            "AKIA",
            "--target-secret-access-key",
            "secret",
        ]);
        assert!(res.is_err(), "explicit keys must conflict with profile");
    }

    #[test]
    fn target_no_sign_request_conflicts_with_profile() {
        let res = TestCli::try_parse_from([
            "test",
            "--target-no-sign-request",
            "--target-profile",
            "default",
        ]);
        assert!(res.is_err(), "no-sign-request must conflict with profile");
    }

    #[test]
    fn target_no_sign_request_conflicts_with_explicit_keys() {
        let res = TestCli::try_parse_from([
            "test",
            "--target-no-sign-request",
            "--target-access-key",
            "AKIA",
            "--target-secret-access-key",
            "secret",
        ]);
        assert!(res.is_err(), "no-sign-request must conflict with explicit keys");
    }

    #[test]
    fn target_no_sign_request_conflicts_with_request_payer() {
        let res = TestCli::try_parse_from([
            "test",
            "--target-no-sign-request",
            "--target-request-payer",
        ]);
        assert!(res.is_err(), "no-sign-request must conflict with request-payer");
    }

    #[test]
    fn target_region_rejects_empty() {
        let res = TestCli::try_parse_from(["test", "--target-region", ""]);
        assert!(res.is_err(), "empty region must be rejected");
    }

    #[test]
    fn build_client_config_uses_environment_credentials_by_default() {
        let cli = TestCli::try_parse_from(["test"]).unwrap();
        let cfg = cli.common.build_client_config();
        assert!(matches!(
            cfg.credential,
            crate::types::S3Credentials::FromEnvironment
        ));
        assert_eq!(cfg.retry_config.aws_max_attempts, DEFAULT_AWS_MAX_ATTEMPTS);
        assert!(!cfg.disable_stalled_stream_protection);
    }

    #[test]
    fn build_client_config_request_payer_follows_default() {
        let cli = TestCli::try_parse_from(["test"]).unwrap();
        let cfg = cli.common.build_client_config();
        assert_eq!(cfg.request_payer.is_some(), DEFAULT_REQUEST_PAYER);
    }

    #[test]
    fn build_client_config_uses_no_sign_request_when_set() {
        let cli = TestCli::try_parse_from(["test", "--target-no-sign-request"]).unwrap();
        let cfg = cli.common.build_client_config();
        assert!(matches!(
            cfg.credential,
            crate::types::S3Credentials::NoSignRequest
        ));
    }

    #[test]
    fn build_client_config_uses_explicit_keys() {
        let cli = TestCli::try_parse_from([
            "test",
            "--target-access-key",
            "AKIA",
            "--target-secret-access-key",
            "secret",
        ])
        .unwrap();
        let cfg = cli.common.build_client_config();
        match cfg.credential {
            crate::types::S3Credentials::Credentials { access_keys } => {
                assert_eq!(access_keys.access_key, "AKIA");
                assert_eq!(access_keys.secret_access_key, "secret");
                assert!(access_keys.session_token.is_none());
            }
            other => panic!("expected Credentials, got {other:?}"),
        }
    }

    #[test]
    fn build_client_config_propagates_session_token() {
        let cli = TestCli::try_parse_from([
            "test",
            "--target-access-key",
            "AKIA",
            "--target-secret-access-key",
            "secret",
            "--target-session-token",
            "token",
        ])
        .unwrap();
        let cfg = cli.common.build_client_config();
        match cfg.credential {
            crate::types::S3Credentials::Credentials { access_keys } => {
                assert_eq!(access_keys.session_token.as_deref(), Some("token"));
            }
            other => panic!("expected Credentials, got {other:?}"),
        }
    }

    #[test]
    fn build_client_config_uses_profile_when_set() {
        let cli = TestCli::try_parse_from(["test", "--target-profile", "prod"]).unwrap();
        let cfg = cli.common.build_client_config();
        match cfg.credential {
            crate::types::S3Credentials::Profile(name) => assert_eq!(name, "prod"),
            other => panic!("expected Profile, got {other:?}"),
        }
    }

    #[test]
    fn build_client_config_propagates_request_payer_and_accelerate() {
        let cli =
            TestCli::try_parse_from(["test", "--target-request-payer", "--target-accelerate"])
                .unwrap();
        let cfg = cli.common.build_client_config();
        assert_eq!(
            cfg.request_payer,
            Some(aws_sdk_s3::types::RequestPayer::Requester)
        );
        assert!(cfg.accelerate);
    }

    #[test]
    fn build_client_config_propagates_region_retry_and_timeouts() {
        let cli = TestCli::try_parse_from([
            "test",
            "--target-region",
            "us-west-2",
            "--aws-max-attempts",
            "7",
            "--initial-backoff-milliseconds",
            "250",
            "--operation-timeout-milliseconds",
            "1000",
            "--connect-timeout-milliseconds",
            "500",
        ])
        .unwrap();
        let cfg = cli.common.build_client_config();
        assert_eq!(cfg.region.as_deref(), Some("us-west-2"));
        assert_eq!(cfg.retry_config.aws_max_attempts, 7);
        assert_eq!(cfg.retry_config.initial_backoff_milliseconds, 250);
        assert_eq!(
            cfg.cli_timeout_config.operation_timeout_milliseconds,
            Some(1000)
        );
        assert_eq!(cfg.cli_timeout_config.connect_timeout_milliseconds, Some(500));
        assert!(cfg
            .cli_timeout_config
            .operation_attempt_timeout_milliseconds
            .is_none());
        assert!(cfg.cli_timeout_config.read_timeout_milliseconds.is_none());
    }

    #[test]
    fn build_tracing_config_returns_some_at_default_verbosity() {
        let cli = TestCli::try_parse_from(["test"]).unwrap();
        let cfg = cli.common.build_tracing_config();
        assert!(cfg.is_some());
        let cfg = cfg.unwrap();
        assert_eq!(cfg.tracing_level, log::Level::Warn);
        assert!(!cfg.json_tracing);
    }

    #[test]
    fn build_tracing_config_returns_none_when_silenced() {
        let cli = TestCli::try_parse_from(["test", "-qqq"]).unwrap();
        let cfg = cli.common.build_tracing_config();
        assert!(cfg.is_none(), "expected None when verbosity below Error");
    }

    #[test]
    fn build_tracing_config_propagates_flags() {
        let cli = TestCli::try_parse_from([
            "test",
            "--json-tracing",
            "--aws-sdk-tracing",
            "--span-events-tracing",
            "--disable-color-tracing",
        ])
        .unwrap();
        let cfg = cli.common.build_tracing_config().unwrap();
        assert!(cfg.json_tracing);
        assert!(cfg.aws_sdk_tracing);
        assert!(cfg.span_events_tracing);
        assert!(cfg.disable_color_tracing);
    }

    #[test]
    fn build_tracing_config_dry_run_false_matches_normal() {
        let cli = TestCli::try_parse_from(["test"]).unwrap();
        let normal = cli.common.build_tracing_config();
        let dry = cli.common.build_tracing_config_dry_run(false);
        assert_eq!(
            normal.map(|c| c.tracing_level),
            dry.map(|c| c.tracing_level)
        );
    }

    #[test]
    fn build_tracing_config_dry_run_false_returns_none_when_silenced() {
        let cli = TestCli::try_parse_from(["test", "-qqq"]).unwrap();
        let cfg = cli.common.build_tracing_config_dry_run(false);
        assert!(cfg.is_none());
    }

    #[test]
    fn build_tracing_config_dry_run_bumps_default_warn_to_info() {
        let cli = TestCli::try_parse_from(["test"]).unwrap();
        let cfg = cli.common.build_tracing_config_dry_run(true).unwrap();
        assert_eq!(cfg.tracing_level, log::Level::Info);
    }

    #[test]
    fn build_tracing_config_dry_run_keeps_explicit_info() {
        let cli = TestCli::try_parse_from(["test", "-v"]).unwrap();
        let cfg = cli.common.build_tracing_config_dry_run(true).unwrap();
        assert_eq!(cfg.tracing_level, log::Level::Info);
    }

    #[test]
    fn build_tracing_config_dry_run_preserves_debug() {
        let cli = TestCli::try_parse_from(["test", "-vv"]).unwrap();
        let cfg = cli.common.build_tracing_config_dry_run(true).unwrap();
        assert_eq!(
            cfg.tracing_level,
            log::Level::Debug,
            "debug must not be downgraded to info"
        );
    }

    #[test]
    fn build_tracing_config_dry_run_preserves_trace() {
        let cli = TestCli::try_parse_from(["test", "-vvv"]).unwrap();
        let cfg = cli.common.build_tracing_config_dry_run(true).unwrap();
        assert_eq!(
            cfg.tracing_level,
            log::Level::Trace,
            "trace must not be downgraded to info"
        );
    }

    #[test]
    fn build_tracing_config_dry_run_bumps_error_to_info() {
        let cli = TestCli::try_parse_from(["test", "-q"]).unwrap();
        let cfg = cli.common.build_tracing_config_dry_run(true).unwrap();
        assert_eq!(cfg.tracing_level, log::Level::Info);
    }

    #[test]
    fn build_tracing_config_dry_run_bumps_silenced_to_info() {
        let cli = TestCli::try_parse_from(["test", "-qqq"]).unwrap();
        let cfg = cli.common.build_tracing_config_dry_run(true);
        assert!(cfg.is_some(), "dry-run must override -qqq silencing");
        assert_eq!(cfg.unwrap().tracing_level, log::Level::Info);
    }

    #[test]
    fn build_tracing_config_dry_run_propagates_format_flags() {
        let cli = TestCli::try_parse_from([
            "test",
            "--json-tracing",
            "--aws-sdk-tracing",
            "--span-events-tracing",
            "--disable-color-tracing",
            "-qqq",
        ])
        .unwrap();
        let cfg = cli.common.build_tracing_config_dry_run(true).unwrap();
        assert!(cfg.json_tracing);
        assert!(cfg.aws_sdk_tracing);
        assert!(cfg.span_events_tracing);
        assert!(cfg.disable_color_tracing);
        assert_eq!(cfg.tracing_level, log::Level::Info);
    }
}