#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
use clap::Parser;
use ripcurl::transfer::{ProgressState, format_progress_log};
use std::io::IsTerminal;
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::Arc;
use std::time::Duration;
use tracing_indicatif::IndicatifLayer;
use tracing_indicatif::filter::{IndicatifFilter, hide_indicatif_span_fields};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::Layer;
use tracing_subscriber::fmt::format::DefaultFields;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use url::Url;
// How the user asked for transfer progress to be reported (the `--progress` value).
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into help text, so doc comments would change `--help` output.
#[derive(Clone, Copy, Default, clap::ValueEnum)]
enum ProgressMode {
// Always resolves to `ResolvedProgressMode::Bar`.
Bar,
// Always resolves to `ResolvedProgressMode::Log`.
Log,
// Default choice. `resolve_progress_mode` decides between bar, log and none
// based on whether stderr is a terminal and whether the `CI` variable is set.
#[default]
Auto,
// Always resolves to `ResolvedProgressMode::None`.
None,
}
/// The concrete progress-reporting mode after `ProgressMode::Auto` has been
/// resolved by `resolve_progress_mode`; it has no `Auto` variant.
enum ResolvedProgressMode {
/// `main` installs the indicatif tracing layer and routes log output through its stderr writer.
Bar,
/// `main` creates a shared `ProgressState` and spawns a task that periodically logs it.
Log,
/// `main` installs the plain fmt subscriber and creates no progress state.
None,
}
// Command-line arguments of the `ripcurl` binary, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into help text, so doc comments would change `--help` output.
#[derive(Parser)]
#[command(name = "ripcurl", about, version)]
struct Cli {
// Where to read from. `main` hands this to `parse_url`, which accepts either a
// string containing "://" (parsed as a URL) or a filesystem path.
source: String,
// Where to write to. Parsed by `parse_url` exactly like `source`.
destination: String,
// Forwarded unchanged to `TransferConfig::overwrite`.
// Presumably permits replacing an existing destination; confirm in the transfer module.
#[arg(long, help_heading = "General Options")]
overwrite: bool,
// Forwarded unchanged to `TransferConfig::max_retries`; defaults to 10.
#[arg(long, default_value_t = 10, help_heading = "General Options")]
max_retries: u32,
// Requested progress reporting; defaults to `ProgressMode::Auto` (the enum's `Default`).
#[arg(long, default_value_t, value_enum, help_heading = "General Options")]
progress: ProgressMode,
// HTTP-specific options, flattened into the top-level argument list.
#[command(flatten)]
http: HttpOptions,
}
// Options grouped under the "HTTP Options" help heading.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into help text, so doc comments would change `--help` output.
#[derive(clap::Args)]
#[command(next_help_heading = "HTTP Options")]
struct HttpOptions {
// Raw `-H` / `--header` values, one entry per occurrence on the command line.
// `main` runs each through `parse_header`, which expects "Name: Value".
#[arg(short = 'H', long = "header")]
headers: Vec<String>,
}
/// How long the progress-logging task in `main` sleeps before emitting its first log line.
const LOG_PROGRESS_INITIAL_DELAY: Duration = Duration::from_secs(10);
/// How long the progress-logging task in `main` sleeps after each log line before emitting the next.
const LOG_PROGRESS_INTERVAL: Duration = Duration::from_secs(60);
#[tokio::main]
async fn main() -> ExitCode {
let cli = Cli::parse();
let progress_mode = resolve_progress_mode(cli.progress);
match progress_mode {
ResolvedProgressMode::Bar => {
let indicatif_layer = IndicatifLayer::new()
.with_span_field_formatter(hide_indicatif_span_fields(DefaultFields::new()));
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
.with(
tracing_subscriber::fmt::layer()
.with_writer(indicatif_layer.get_stderr_writer()),
)
.with(indicatif_layer.with_filter(IndicatifFilter::new(false)))
.init();
}
ResolvedProgressMode::Log | ResolvedProgressMode::None => {
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
.with(tracing_subscriber::fmt::layer())
.init();
}
}
let source_url = match parse_url(&cli.source) {
Ok(url) => url,
Err(e) => {
eprintln!("Invalid source URL: {e}");
return ExitCode::from(3);
}
};
let dest_url = match parse_url(&cli.destination) {
Ok(url) => url,
Err(e) => {
eprintln!("Invalid destination URL: {e}");
return ExitCode::from(3);
}
};
let custom_http_headers: Vec<(String, String)> = match cli
.http
.headers
.iter()
.map(|h| parse_header(h))
.collect::<Result<_, _>>()
{
Ok(h) => h,
Err(e) => {
eprintln!("Invalid header: {e}");
return ExitCode::from(3);
}
};
let config = ripcurl::transfer::TransferConfig {
max_retries: cli.max_retries,
overwrite: cli.overwrite,
custom_http_headers,
};
let progress_state = match progress_mode {
ResolvedProgressMode::Log => Some(Arc::new(ProgressState::new())),
ResolvedProgressMode::Bar | ResolvedProgressMode::None => None,
};
let log_handle = progress_state.as_ref().map(|ps| {
let ps = Arc::clone(ps);
tokio::spawn(async move {
tokio::time::sleep(LOG_PROGRESS_INITIAL_DELAY).await;
loop {
let bytes = ps.bytes_written();
let total = ps.total_size();
let elapsed = ps.elapsed();
tracing::info!("{}", format_progress_log(bytes, total, elapsed));
tokio::time::sleep(LOG_PROGRESS_INTERVAL).await;
}
})
});
let result =
ripcurl::transfer::execute_transfer(source_url, dest_url, &config, progress_state).await;
if let Some(handle) = log_handle {
handle.abort();
}
match result {
Ok(_bytes) => ExitCode::SUCCESS,
Err(
ripcurl::protocol::TransferError::Permanent { reason }
| ripcurl::protocol::TransferError::Transient { reason, .. },
) => {
eprintln!("Transfer failed: {reason}");
ExitCode::FAILURE
}
}
}
/// Turns the requested `ProgressMode` into a concrete `ResolvedProgressMode`.
///
/// `Bar`, `Log` and `None` map one-to-one. `Auto` resolves to `Log` when
/// `is_ci()` is true, otherwise to `Bar` when stderr is a terminal, and to
/// `None` in every remaining case.
fn resolve_progress_mode(mode: ProgressMode) -> ResolvedProgressMode {
    match mode {
        ProgressMode::Bar => ResolvedProgressMode::Bar,
        ProgressMode::Log => ResolvedProgressMode::Log,
        ProgressMode::None => ResolvedProgressMode::None,
        // The guards are tried top to bottom; the unguarded arm keeps the match exhaustive.
        ProgressMode::Auto if is_ci() => ResolvedProgressMode::Log,
        ProgressMode::Auto if std::io::stderr().is_terminal() => ResolvedProgressMode::Bar,
        ProgressMode::Auto => ResolvedProgressMode::None,
    }
}
/// Reports whether the `CI` environment variable is present.
///
/// Only presence is checked: any value, including an empty one, counts.
fn is_ci() -> bool {
    matches!(std::env::var_os("CI"), Some(_))
}
/// Splits a raw `"Name: Value"` header argument into a `(name, value)` pair.
///
/// The split happens at the first `:`, so the value may itself contain colons.
/// Both parts are trimmed of surrounding whitespace; an empty value is allowed.
///
/// # Errors
///
/// Returns a message describing the problem when `raw` contains no `:` or when
/// the name is empty after trimming.
fn parse_header(raw: &str) -> Result<(String, String), String> {
    let Some((name, value)) = raw.split_once(':') else {
        return Err(format!("expected \"Name: Value\" format, got \"{raw}\""));
    };
    match name.trim() {
        "" => Err(format!("header name cannot be empty in \"{raw}\"")),
        name => Ok((name.to_owned(), value.trim().to_owned())),
    }
}
/// Interprets a command-line source/destination argument as a `Url`.
///
/// Input containing `"://"` is parsed as a URL. Anything else is treated as a
/// filesystem path: relative paths are joined onto the current working
/// directory, and the resulting absolute path is converted to a file URL.
///
/// # Errors
///
/// Returns a message when URL parsing fails, when the current directory cannot
/// be determined, or when the path cannot be converted to a URL.
fn parse_url(input: &str) -> Result<Url, String> {
    if !input.contains("://") {
        let given = PathBuf::from(input);
        let absolute = if given.is_relative() {
            let cwd = std::env::current_dir()
                .map_err(|e| format!("could not resolve current directory: {e}"))?;
            cwd.join(given)
        } else {
            given
        };
        return Url::from_file_path(&absolute)
            .map_err(|()| format!("could not convert path to URL: {}", absolute.display()));
    }

    Url::parse(input).map_err(|e| e.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `Ok((name, value))` result a successful `parse_header` call is expected to return.
    fn parsed(name: &str, value: &str) -> Result<(String, String), String> {
        Ok((name.to_string(), value.to_string()))
    }

    #[test]
    fn parse_header_basic() {
        assert_eq!(
            parse_header("Authorization: Bearer tok123"),
            parsed("Authorization", "Bearer tok123")
        );
    }

    /// Only the first colon separates name from value.
    #[test]
    fn parse_header_colons_in_value() {
        assert_eq!(
            parse_header("X-Data: has:colons:in:value"),
            parsed("X-Data", "has:colons:in:value")
        );
    }

    #[test]
    fn parse_header_trims_whitespace() {
        assert_eq!(parse_header("  Name  :  value  "), parsed("Name", "value"));
    }

    #[test]
    fn parse_header_empty_value() {
        assert_eq!(parse_header("Name:"), parsed("Name", ""));
    }

    #[test]
    fn parse_header_missing_colon() {
        let outcome = parse_header("InvalidNoColon");
        assert!(outcome.is_err());
    }

    #[test]
    fn parse_header_empty_name() {
        let outcome = parse_header(": value");
        assert!(outcome.is_err());
    }
}