cloud-copy 0.8.0

A library for copying files to and from cloud storage.
//! Cloud storage copy utility.

use std::io::IsTerminal;
use std::io::stderr;
use std::path::PathBuf;
use std::str::FromStr;

use anyhow::Context;
use anyhow::Result;
use anyhow::bail;
use byte_unit::Byte;
use byte_unit::UnitType;
use chrono::Utc;
use clap::Parser;
use clap::ValueEnum;
use clap_verbosity_flag::Verbosity;
use clap_verbosity_flag::WarnLevel;
use cloud_copy::AzureConfig;
use cloud_copy::Config;
use cloud_copy::GoogleConfig;
use cloud_copy::HashAlgorithm;
use cloud_copy::HttpClient;
use cloud_copy::Location;
use cloud_copy::S3Config;
use cloud_copy::cli::TimeDeltaExt;
use cloud_copy::cli::handle_events;
use cloud_copy::copy;
use colored::Colorize;
use git_testament::git_testament;
use git_testament::render_testament;
use secrecy::SecretString;
use tokio::pin;
use tokio::sync::broadcast;
use tokio_util::sync::CancellationToken;
use tracing::level_filters::LevelFilter;
use tracing_indicatif::IndicatifLayer;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::layer::SubscriberExt;

// Capture build-time git information in `TESTAMENT`; `Args` renders it as the
// `--version` string through `render_testament!`.
git_testament!(TESTAMENT);

/// Represents the supported output color modes.
// `#[derive(ValueEnum)]` lets clap use these variants as the possible values of
// `--color`, and clap shows each variant's `///` comment as that value's help
// text — edit those comments knowing they are user-visible.
#[derive(Debug, Default, Clone, ValueEnum, Copy, PartialEq, Eq, Hash)]
pub enum ColorMode {
    /// Automatically colorize output depending on output device.
    // `run` resolves this by checking whether stderr is a terminal.
    #[default]
    Auto,
    /// Always colorize output.
    Always,
    /// Never colorize output.
    Never,
}

impl FromStr for ColorMode {
    type Err = anyhow::Error;

    /// Parses one of the lowercase mode names `auto`, `always`, or `never`.
    ///
    /// Matching is case-sensitive; any other input is rejected with an error
    /// that quotes the offending text.
    fn from_str(s: &str) -> Result<Self> {
        let mode = match s {
            "always" => Self::Always,
            "never" => Self::Never,
            "auto" => Self::Auto,
            _ => bail!("invalid color mode `{s}`"),
        };

        Ok(mode)
    }
}

impl std::fmt::Display for ColorMode {
    /// Writes the lowercase name of the mode, i.e. the same spelling accepted
    /// by the `FromStr` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Auto => "auto",
            Self::Always => "always",
            Self::Never => "never",
        };

        f.write_str(name)
    }
}

/// A utility for transferring files to and from cloud storage services.
// clap renders the `///` comments on this struct and its fields as `--help`
// text, so changing them changes user-visible output; maintainer notes use `//`.
#[derive(Parser, Debug)]
#[command(version = render_testament!(TESTAMENT), propagate_version = true)]
struct Args {
    /// The source location to copy from.
    #[clap(value_name = "SOURCE")]
    source: String,

    /// The destination location to copy to.
    #[clap(value_name = "DESTINATION")]
    destination: String,

    /// The cache directory to use for downloads.
    // When set, `into_parts` builds the client with `HttpClient::new_with_cache`;
    // otherwise the default client is used.
    #[clap(long, value_name = "DIR")]
    cache_dir: Option<PathBuf>,

    /// The hash algorithm to use for calculating content digests.
    ///
    /// Defaults to `sha256`.
    // NOTE(review): when omitted, `into_parts` falls back to
    // `HashAlgorithm::default()`; the `sha256` claim above assumes that is the
    // default variant — confirm.
    #[clap(long, value_name = "ALGO")]
    hash_algorithm: Option<HashAlgorithm>,

    /// Whether or not to create hard links to existing cached files.
    #[clap(long)]
    link_to_cache: bool,

    /// Whether or not to overwrite the destination.
    #[clap(long)]
    overwrite: bool,

    /// The block size to use for file transfers; the default block size depends
    /// on the cloud service.
    // Passed unchanged to `Config`; presumably a byte count — confirm units.
    #[clap(long, value_name = "SIZE")]
    block_size: Option<u64>,

    /// The parallelism level for network operations; defaults to the host's
    /// available parallelism multiplied by 2.
    // Not computed in this file: `None` is forwarded to `Config`, which is
    // presumably where the default above is applied.
    #[clap(long, value_name = "NUM")]
    parallelism: Option<usize>,

    /// The number of retries to attempt for network operations.
    #[clap(long, value_name = "RETRIES")]
    retries: Option<usize>,

    // Credentials: each pair below is linked with mutual `requires`, so clap
    // rejects a half-specified pair. With `env`, a value may also come from an
    // environment variable named after the field in upper snake case (e.g.
    // `AZURE_ACCOUNT_NAME`); `hide_env_values` keeps secret values out of
    // `--help`.
    /// The Azure Storage Account Name to use.
    #[clap(long, env, value_name = "NAME", requires = "azure_access_key")]
    azure_account_name: Option<String>,

    /// The Azure Storage Access Key to use.
    #[clap(
        long,
        env,
        hide_env_values(true),
        value_name = "KEY",
        requires = "azure_account_name"
    )]
    azure_access_key: Option<SecretString>,

    /// The AWS Access Key ID to use.
    #[clap(long, env, value_name = "ID", requires = "aws_secret_access_key")]
    aws_access_key_id: Option<String>,

    /// The AWS Secret Access Key to use.
    #[clap(
        long,
        env,
        hide_env_values(true),
        value_name = "KEY",
        requires = "aws_access_key_id"
    )]
    aws_secret_access_key: Option<SecretString>,

    /// The default AWS region.
    // Applied to the S3 config via `with_maybe_region`, independently of whether
    // the access key pair was given.
    #[clap(long, env, value_name = "REGION")]
    aws_default_region: Option<String>,

    /// The Google Cloud Storage HMAC access key to use.
    #[clap(long, env, value_name = "KEY", requires = "google_hmac_secret")]
    google_hmac_access_key: Option<String>,

    /// The Google Cloud Storage HMAC secret to use.
    #[clap(
        long,
        env,
        hide_env_values(true),
        value_name = "SECRET",
        requires = "google_hmac_access_key"
    )]
    google_hmac_secret: Option<SecretString>,

    /// The verbosity level.
    // Flattened `-v`/`-q` flags; `WarnLevel` makes `warn` the default level.
    #[command(flatten)]
    verbosity: Verbosity<WarnLevel>,

    /// Controls output colorization.
    // `global = true` has no effect today because this CLI defines no
    // subcommands.
    #[arg(long, default_value = "auto", global = true)]
    color: ColorMode,
}

impl Args {
    /// Consumes the parsed arguments and produces the copy `Config`, the
    /// `HttpClient` to transfer with, and the raw source and destination
    /// strings.
    ///
    /// A cloud service's credentials are applied only when both parts of its
    /// pair are present; otherwise that service keeps its default
    /// configuration.
    fn into_parts(self) -> (Config, HttpClient, String, String) {
        let azure = match (self.azure_account_name, self.azure_access_key) {
            (Some(account_name), Some(access_key)) => {
                AzureConfig::default().with_auth(account_name, access_key)
            }
            _ => AzureConfig::default(),
        };

        let s3_base = match (self.aws_access_key_id, self.aws_secret_access_key) {
            (Some(id), Some(key)) => S3Config::default().with_auth(id, key),
            _ => S3Config::default(),
        };
        // The region is independent of the credentials and is applied either way.
        let s3 = s3_base.with_maybe_region(self.aws_default_region);

        let google = match (self.google_hmac_access_key, self.google_hmac_secret) {
            (Some(access_key), Some(secret)) => {
                GoogleConfig::default().with_auth(access_key, secret)
            }
            _ => GoogleConfig::default(),
        };

        let config = Config::builder()
            .with_hash_algorithm(self.hash_algorithm.unwrap_or_default())
            .with_link_to_cache(self.link_to_cache)
            .with_overwrite(self.overwrite)
            .with_maybe_block_size(self.block_size)
            .with_maybe_parallelism(self.parallelism)
            .with_maybe_retries(self.retries)
            .with_azure(azure)
            .with_s3(s3)
            .with_google(google)
            .build();

        // Only a caching client is built from the config; without a cache
        // directory the client's `Default` implementation is used.
        let client = match self.cache_dir {
            Some(dir) => HttpClient::new_with_cache(config.clone(), dir),
            None => HttpClient::default(),
        };

        (config, client, self.source, self.destination)
    }
}

/// Computes the average transfer rate, in bytes per second, for `bytes`
/// transferred over `elapsed_ms` milliseconds.
///
/// Elapsed times below one millisecond are treated as one millisecond, so the
/// division is always defined. This includes negative values, which can occur
/// because the caller measures time with the wall clock (`Utc::now`). The
/// result saturates at `u64::MAX`.
fn bytes_per_second(bytes: u64, elapsed_ms: i64) -> u64 {
    let elapsed_ms = u128::from(elapsed_ms.max(1).unsigned_abs());
    // u128 cannot overflow here: u64::MAX * 1000 fits comfortably.
    let rate = u128::from(bytes) * 1000 / elapsed_ms;
    u64::try_from(rate).unwrap_or(u64::MAX)
}

/// Runs the application.
///
/// Parses the command line and installs the global tracing subscriber. It then
/// spawns the transfer event handler and performs the copy. On success, if the
/// event handler reported statistics, it prints a one-line summary to stdout.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// - `RUST_LOG` is set but invalid.
/// - The tracing subscriber cannot be installed.
/// - The copy itself fails.
///
/// # Panics
///
/// Panics if the event handler task cannot be joined (e.g. it panicked).
async fn run(cancel: CancellationToken) -> Result<()> {
    let args = Args::parse();

    let colorize = match args.color {
        ColorMode::Auto => stderr().is_terminal(),
        ColorMode::Always => true,
        ColorMode::Never => false,
    };

    // Try to get a default environment filter via `RUST_LOG`
    let env_filter = match EnvFilter::try_from_default_env()
        .context("invalid `RUST_LOG` environment variable")
    {
        Ok(filter) => filter,
        Err(e) => {
            // If there was an error and the variable was set, then the error was due to
            // parsing an invalid directive
            if std::env::var("RUST_LOG").is_ok() {
                return Err(e);
            }

            // Otherwise, use a default directive env filter that disables noisy hyper
            // output
            EnvFilter::builder()
                .with_default_directive(LevelFilter::from(args.verbosity).into())
                .from_env_lossy()
                .add_directive("hyper_util=off".parse()?)
                .add_directive("h2=off".parse()?)
        }
    };

    // Build the subscriber and set it as the global default. Log output goes
    // through the indicatif layer's stderr writer so it is not interleaved with
    // active progress bars.
    let indicatif_layer = IndicatifLayer::new();
    let subscriber = tracing_subscriber::fmt::Subscriber::builder()
        .with_env_filter(env_filter)
        .with_writer(indicatif_layer.get_stderr_writer())
        .with_ansi(colorize)
        .finish()
        .with(indicatif_layer);

    colored::control::set_override(colorize);

    tracing::subscriber::set_global_default(subscriber)
        .context("failed to set tracing subscriber")?;

    // Transfer events are consumed by a background task. That task displays
    // progress and yields the final statistics once it finishes. It is spawned
    // unconditionally; `colorize` only controls how it renders.
    let (events_tx, events_rx) = broadcast::channel(1000);
    let c = cancel.clone();
    let handler = tokio::spawn(async move { handle_events(events_rx, colorize, c).await });

    let start = Utc::now();

    let (config, client, source, destination) = args.into_parts();
    let result = copy(
        config,
        client,
        &source,
        &destination,
        cancel,
        Some(events_tx),
    )
    .await
    .with_context(|| {
        format!(
            "failed to copy `{source}` to `{destination}`",
            source = Location::new(&source),
            destination = Location::new(&destination),
        )
    });

    let end = Utc::now();

    let stats = handler.await.expect("failed to join events handler");

    // Print the statistics upon success
    if result.is_ok()
        && let Some(stats) = stats
    {
        let delta = end - start;

        println!(
            "{files} file{s} copied with a total of {bytes:#} transferred in {time} ({speed})",
            files = stats.files.to_string().cyan(),
            s = if stats.files == 1 { "" } else { "s" },
            bytes = format!(
                "{:#.3}",
                Byte::from_u64(stats.bytes).get_appropriate_unit(UnitType::Binary)
            )
            .cyan(),
            time = delta.english().to_string().cyan(),
            speed = format!(
                "{bytes:#.3}/s",
                bytes = Byte::from_u64(bytes_per_second(stats.bytes, delta.num_milliseconds()))
                    .get_appropriate_unit(UnitType::Binary)
            )
            .cyan()
        );
    }

    result
}

/// Waits for SIGTERM or SIGINT, logs which one arrived, and then cancels
/// `cancel` to request a graceful shutdown.
///
/// # Panics
///
/// Panics if either signal listener cannot be registered.
#[cfg(unix)]
async fn terminate(cancel: CancellationToken) {
    use tokio::signal::unix::{SignalKind, signal};
    use tracing::info;

    let mut sigterm = signal(SignalKind::terminate()).expect("failed to create SIGTERM handler");
    let mut sigint = signal(SignalKind::interrupt()).expect("failed to create SIGINT handler");

    // The first stream to yield decides which signal name gets logged.
    let signal = tokio::select! {
        _ = sigterm.recv() => "SIGTERM",
        _ = sigint.recv() => "SIGINT",
    };

    info!("received {signal} signal: initiating shutdown");
    cancel.cancel();
}

/// Waits for a Ctrl-C event, logs it, and then cancels `cancel` to request a
/// graceful shutdown.
///
/// # Panics
///
/// Panics if the Ctrl-C listener cannot be registered.
#[cfg(windows)]
async fn terminate(cancel: CancellationToken) {
    let mut ctrl_c_events =
        tokio::signal::windows::ctrl_c().expect("failed to create ctrl-c handler");
    ctrl_c_events.recv().await;

    tracing::info!("received Ctrl-C signal: initiating shutdown");
    cancel.cancel();
}

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();

    let run = run(cancel.clone());
    pin!(run);

    loop {
        tokio::select! {
            biased;
            _ = terminate(cancel.clone()) => continue,
            r = &mut run => {
                if let Err(e) = r {
                    eprintln!(
                        "{error}: {e:?}",
                        error = if std::io::stderr().is_terminal() {
                            "error".red().bold()
                        } else {
                            "error".normal()
                        }
                    );

                    std::process::exit(1);
                }

                break;
            }
        }
    }
}