use std::fmt;
use std::fs::File;
use std::io::BufReader;
use std::num::NonZero;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::thread::available_parallelism;
use std::time::Duration;
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use secrecy::SecretString;
use serde::Deserialize;
use sha2::Digest;
use tokio::select;
use tokio::task::spawn_blocking;
use tokio_retry2::strategy::ExponentialFactorBackoff;
use tokio_retry2::strategy::MaxInterval;
use tokio_util::sync::CancellationToken;
use tracing::info;
/// Default number of retry attempts when no explicit retry count is configured.
const DEFAULT_RETRIES: usize = 5;
/// Region used by the S3 backend when the configuration does not specify one.
const DEFAULT_REGION: &str = "us-east-1";
/// Algorithm used to compute a file's content digest (see
/// [`HashAlgorithm::calculate_content_digest`]).
///
/// Deserializes from snake_case names and, with the `cli` feature, doubles as
/// a clap value enum.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "cli", derive(clap::ValueEnum))]
pub enum HashAlgorithm {
    /// Skip digest calculation entirely.
    None,
    /// SHA-256 digest (the default).
    #[default]
    Sha256,
    /// BLAKE3 digest.
    Blake3,
}
impl HashAlgorithm {
pub async fn calculate_content_digest(
&self,
path: &Path,
cancel: &CancellationToken,
) -> crate::Result<Option<String>> {
info!(
"calculating content digest for file `{path}`",
path = path.display()
);
match self {
Self::None => Ok(None),
Self::Sha256 => {
let path = path.to_path_buf();
let fut = spawn_blocking(move || {
let mut hasher = sha2::Sha256::new();
let mut reader = BufReader::new(File::open(path)?);
std::io::copy(&mut reader, &mut hasher)?;
let digest = hasher.finalize();
Ok(Some(format!(
"sha-256=:{encoded}:",
encoded = BASE64_STANDARD.encode(digest)
)))
});
select! {
_ = cancel.cancelled() => Err(crate::Error::Canceled),
r = fut => r.expect("failed to join task")
}
}
Self::Blake3 => {
let path = path.to_path_buf();
let fut = spawn_blocking(move || {
let mut hasher = blake3::Hasher::new();
hasher.update_mmap_rayon(&path)?;
let digest = hasher.finalize();
Ok(Some(format!(
"blake3=:{encoded}:",
encoded = BASE64_STANDARD.encode(digest.as_bytes())
)))
});
select! {
_ = cancel.cancelled() => Err(crate::Error::Canceled),
r = fut => r.expect("failed to join task")
}
}
}
}
}
impl FromStr for HashAlgorithm {
    type Err = String;

    /// Parses the lowercase algorithm name ("none", "sha256", "blake3"),
    /// matching the names produced by the `Display` impl.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let algorithm = match s {
            "none" => Self::None,
            "sha256" => Self::Sha256,
            "blake3" => Self::Blake3,
            other => return Err(format!("invalid digest algorithm `{other}`")),
        };
        Ok(algorithm)
    }
}
impl fmt::Display for HashAlgorithm {
    /// Writes the lowercase algorithm name accepted back by `FromStr`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::None => "none",
            Self::Sha256 => "sha256",
            Self::Blake3 => "blake3",
        };
        f.write_str(name)
    }
}
/// Static credentials for the Azure storage backend.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct AzureAuthConfig {
    account_name: String,
    access_key: SecretString,
}

impl AzureAuthConfig {
    /// The storage account name.
    pub fn account_name(&self) -> &str {
        self.account_name.as_str()
    }

    /// The account access key, wrapped in [`SecretString`] to avoid
    /// accidental exposure in logs/debug output.
    pub fn access_key(&self) -> &SecretString {
        &self.access_key
    }
}
/// Configuration for the Azure storage backend.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct AzureConfig {
    /// Optional static credentials.
    #[serde(default)]
    auth: Option<AzureAuthConfig>,
    /// Target a local Azurite emulator instead of real Azure (inferred from
    /// the name; the flag is consumed outside this module).
    #[serde(default)]
    use_azurite: bool,
}

impl AzureConfig {
    /// Returns the config with static credentials set.
    pub fn with_auth(
        self,
        account_name: impl Into<String>,
        access_key: impl Into<SecretString>,
    ) -> Self {
        let auth = AzureAuthConfig {
            account_name: account_name.into(),
            access_key: access_key.into(),
        };
        Self {
            auth: Some(auth),
            ..self
        }
    }

    /// Returns the config with the Azurite flag set.
    pub fn with_use_azurite(self, use_azurite: bool) -> Self {
        Self { use_azurite, ..self }
    }

    /// The static credentials, if configured.
    pub fn auth(&self) -> Option<&AzureAuthConfig> {
        self.auth.as_ref()
    }

    /// Whether the Azurite flag is set.
    pub fn use_azurite(&self) -> bool {
        self.use_azurite
    }
}
/// Static credentials for the S3 backend.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct S3AuthConfig {
    access_key_id: String,
    secret_access_key: SecretString,
}

impl S3AuthConfig {
    /// The access key id.
    pub fn access_key_id(&self) -> &str {
        self.access_key_id.as_str()
    }

    /// The secret access key, wrapped in [`SecretString`] to avoid
    /// accidental exposure in logs/debug output.
    pub fn secret_access_key(&self) -> &SecretString {
        &self.secret_access_key
    }
}
/// Configuration for the S3 backend.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct S3Config {
    /// Region; [`DEFAULT_REGION`] applies when unset.
    #[serde(default)]
    region: Option<String>,
    /// Optional static credentials.
    #[serde(default)]
    auth: Option<S3AuthConfig>,
    /// Target a local LocalStack instance instead of real S3 (inferred from
    /// the name; the flag is consumed outside this module).
    #[serde(default)]
    use_localstack: bool,
}

impl S3Config {
    /// Returns the config with the region set.
    pub fn with_region(self, region: impl Into<String>) -> Self {
        Self {
            region: Some(region.into()),
            ..self
        }
    }

    /// Returns the config with the region set or cleared.
    pub fn with_maybe_region(self, region: Option<String>) -> Self {
        Self { region, ..self }
    }

    /// Returns the config with static credentials set.
    pub fn with_auth(
        self,
        access_key_id: impl Into<String>,
        secret_access_key: impl Into<SecretString>,
    ) -> Self {
        let auth = S3AuthConfig {
            access_key_id: access_key_id.into(),
            secret_access_key: secret_access_key.into(),
        };
        Self {
            auth: Some(auth),
            ..self
        }
    }

    /// Returns the config with the LocalStack flag set.
    pub fn with_use_localstack(self, use_localstack: bool) -> Self {
        Self { use_localstack, ..self }
    }

    /// The configured region, or [`DEFAULT_REGION`] when none was given.
    pub fn region(&self) -> &str {
        match &self.region {
            Some(region) => region.as_str(),
            None => DEFAULT_REGION,
        }
    }

    /// The static credentials, if configured.
    pub fn auth(&self) -> Option<&S3AuthConfig> {
        self.auth.as_ref()
    }

    /// Whether the LocalStack flag is set.
    pub fn use_localstack(&self) -> bool {
        self.use_localstack
    }
}
/// Static credentials for the Google storage backend.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct GoogleAuthConfig {
    access_key: String,
    secret: SecretString,
}

impl GoogleAuthConfig {
    /// The access key.
    pub fn access_key(&self) -> &str {
        self.access_key.as_str()
    }

    /// The secret, wrapped in [`SecretString`] to avoid accidental exposure
    /// in logs/debug output.
    pub fn secret(&self) -> &SecretString {
        &self.secret
    }
}
/// Configuration for the Google storage backend.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct GoogleConfig {
    /// Optional static credentials.
    #[serde(default)]
    auth: Option<GoogleAuthConfig>,
}

impl GoogleConfig {
    /// Returns the config with static credentials set.
    pub fn with_auth(
        self,
        access_key: impl Into<String>,
        secret: impl Into<SecretString>,
    ) -> Self {
        let auth = GoogleAuthConfig {
            access_key: access_key.into(),
            secret: secret.into(),
        };
        Self { auth: Some(auth) }
    }

    /// The static credentials, if configured.
    pub fn auth(&self) -> Option<&GoogleAuthConfig> {
        self.auth.as_ref()
    }
}
/// Per-backend settings, grouped so [`Config`] can share them behind an
/// `Arc`. Each backend section is optional in serialized form and falls back
/// to its `Default`.
#[derive(Debug, Default, Deserialize)]
struct BackendsConfig {
    /// Azure backend settings.
    #[serde(default)]
    azure: AzureConfig,
    /// S3 backend settings.
    #[serde(default)]
    s3: S3Config,
    /// Google backend settings.
    #[serde(default)]
    google: GoogleConfig,
}
/// Fluent builder producing [`Config`] values; obtain one via
/// [`Config::builder`]. Unset optional fields fall back to the defaults
/// applied by [`Config`]'s accessors.
#[derive(Debug, Default)]
pub struct ConfigBuilder {
    algorithm: HashAlgorithm,
    link_to_cache: bool,
    overwrite: bool,
    block_size: Option<u64>,
    parallelism: Option<usize>,
    retries: Option<usize>,
    backends: BackendsConfig,
}

impl ConfigBuilder {
    /// Sets the content-digest algorithm.
    pub fn with_hash_algorithm(self, algorithm: HashAlgorithm) -> Self {
        Self { algorithm, ..self }
    }

    /// Sets the link-to-cache flag.
    pub fn with_link_to_cache(self, link_to_cache: bool) -> Self {
        Self { link_to_cache, ..self }
    }

    /// Sets the overwrite flag.
    pub fn with_overwrite(self, overwrite: bool) -> Self {
        Self { overwrite, ..self }
    }

    /// Sets the block size in bytes.
    pub fn with_block_size(self, block_size: u64) -> Self {
        Self {
            block_size: Some(block_size),
            ..self
        }
    }

    /// Sets or clears the block size.
    pub fn with_maybe_block_size(self, block_size: Option<u64>) -> Self {
        Self { block_size, ..self }
    }

    /// Sets the parallelism level.
    pub fn with_parallelism(self, parallelism: usize) -> Self {
        Self {
            parallelism: Some(parallelism),
            ..self
        }
    }

    /// Sets or clears the parallelism level.
    pub fn with_maybe_parallelism(self, parallelism: Option<usize>) -> Self {
        Self { parallelism, ..self }
    }

    /// Sets the retry count.
    pub fn with_retries(self, retries: usize) -> Self {
        Self {
            retries: Some(retries),
            ..self
        }
    }

    /// Sets or clears the retry count.
    pub fn with_maybe_retries(self, retries: Option<usize>) -> Self {
        Self { retries, ..self }
    }

    /// Sets the Azure backend configuration.
    pub fn with_azure(mut self, azure: AzureConfig) -> Self {
        self.backends.azure = azure;
        self
    }

    /// Sets the S3 backend configuration.
    pub fn with_s3(mut self, s3: S3Config) -> Self {
        self.backends.s3 = s3;
        self
    }

    /// Sets the Google backend configuration.
    pub fn with_google(mut self, google: GoogleConfig) -> Self {
        self.backends.google = google;
        self
    }

    /// Consumes the builder and produces the final [`Config`], moving the
    /// backend settings behind an `Arc` so config clones share them.
    pub fn build(self) -> Config {
        let Self {
            algorithm,
            link_to_cache,
            overwrite,
            block_size,
            parallelism,
            retries,
            backends,
        } = self;
        Config {
            algorithm,
            link_to_cache,
            overwrite,
            block_size,
            parallelism,
            retries,
            backends: Arc::new(backends),
        }
    }
}
/// Top-level configuration, buildable via [`Config::builder`] or
/// deserializable (every field is optional in serialized form). Cloning is
/// cheap: the backend settings sit behind an `Arc`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Config {
    /// Content-digest algorithm (defaults to `HashAlgorithm`'s default).
    #[serde(default)]
    algorithm: HashAlgorithm,
    /// Link-to-cache flag (consumed outside this module).
    #[serde(default)]
    link_to_cache: bool,
    /// Overwrite flag (consumed outside this module).
    #[serde(default)]
    overwrite: bool,
    /// Transfer block size in bytes; `None` means "unset".
    #[serde(default)]
    block_size: Option<u64>,
    /// Parallelism level; `None` derives a value from the CPU count.
    #[serde(default)]
    parallelism: Option<usize>,
    /// Retry count; `None` falls back to `DEFAULT_RETRIES`.
    #[serde(default)]
    retries: Option<usize>,
    /// Shared per-backend settings.
    #[serde(default)]
    backends: Arc<BackendsConfig>,
}
impl Config {
    /// Creates a [`ConfigBuilder`] with all settings at their defaults.
    pub fn builder() -> ConfigBuilder {
        ConfigBuilder::default()
    }

    /// The configured content-digest algorithm.
    pub fn hash_algorithm(&self) -> HashAlgorithm {
        self.algorithm
    }

    /// Overrides the content-digest algorithm.
    pub fn set_hash_algorithm(&mut self, algorithm: HashAlgorithm) {
        self.algorithm = algorithm;
    }

    /// Whether the link-to-cache flag is set.
    pub fn link_to_cache(&self) -> bool {
        self.link_to_cache
    }

    /// Sets the link-to-cache flag.
    pub fn set_link_to_cache(&mut self, link_to_cache: bool) {
        self.link_to_cache = link_to_cache;
    }

    /// Whether the overwrite flag is set.
    pub fn overwrite(&self) -> bool {
        self.overwrite
    }

    /// Sets the overwrite flag.
    pub fn set_overwrite(&mut self, overwrite: bool) {
        self.overwrite = overwrite;
    }

    /// The configured block size in bytes, if any.
    pub fn block_size(&self) -> Option<u64> {
        self.block_size
    }

    /// Sets the block size in bytes.
    pub fn set_block_size(&mut self, block_size: u64) {
        self.block_size = Some(block_size);
    }

    /// The effective parallelism: the configured value, or twice the number
    /// of available CPUs (treating an undetectable CPU count as 1).
    pub fn parallelism(&self) -> usize {
        self.parallelism
            .unwrap_or_else(|| available_parallelism().map(NonZero::get).unwrap_or(1) * 2)
    }

    /// Sets the parallelism level.
    pub fn set_parallelism(&mut self, parallelism: usize) {
        self.parallelism = Some(parallelism);
    }

    /// The effective retry count, defaulting to [`DEFAULT_RETRIES`].
    pub fn retries(&self) -> usize {
        self.retries.unwrap_or(DEFAULT_RETRIES)
    }

    /// Sets the retry count.
    pub fn set_retries(&mut self, retries: usize) {
        self.retries = Some(retries);
    }

    /// Azure backend settings.
    pub fn azure(&self) -> &AzureConfig {
        &self.backends.azure
    }

    /// S3 backend settings.
    pub fn s3(&self) -> &S3Config {
        &self.backends.s3
    }

    /// Google backend settings.
    pub fn google(&self) -> &GoogleConfig {
        &self.backends.google
    }

    /// Returns the sequence of retry delays: exponential backoff starting at
    /// 1s with factor 2, capped via `max_duration` at 600s, yielding
    /// [`Self::retries`] entries.
    ///
    /// The iterator owns all of its data (the retry count is copied out of
    /// `self` before the chain is built), so it does not borrow `self`. The
    /// previous signature declared a spurious, unconstrained `'a` lifetime
    /// solely to feed `use<'a>`; the empty precise-capture list `use<>`
    /// states "captures nothing" directly, without the phantom parameter.
    pub fn retry_durations(&self) -> impl Iterator<Item = Duration> + use<> {
        const INITIAL_DELAY_MILLIS: u64 = 1000;
        const BASE_FACTOR: f64 = 2.0;
        const MAX_DURATION: Duration = Duration::from_secs(600);

        ExponentialFactorBackoff::from_millis(INITIAL_DELAY_MILLIS, BASE_FACTOR)
            .max_duration(MAX_DURATION)
            .take(self.retries())
    }
}