use crate::block_device::get_block_device_size;
use aws_sdk_ebs::primitives::ByteStream;
use aws_sdk_ebs::types::{ChecksumAggregationMethod, ChecksumAlgorithm, Tag};
use aws_sdk_ebs::Client as EbsClient;
use base64::engine::general_purpose::STANDARD as base64_engine;
use base64::Engine as _;
use bytes::BytesMut;
use futures::stream::{self, StreamExt};
use indicatif::ProgressBar;
use log::{debug, info, warn};
use sha2::{Digest, Sha256};
use snafu::{ensure, OptionExt, ResultExt, Snafu};
use std::cmp;
use std::collections::BTreeMap;
use std::convert::TryFrom;
use std::ffi::OsStr;
use std::io::SeekFrom;
use std::os::unix::fs::FileTypeExt;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicI32, AtomicU64, Ordering as AtomicOrdering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::fs::{self, File};
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use tokio::time;
/// Opaque error type returned by [`SnapshotUploader`] operations.
#[derive(Debug, Snafu)]
pub struct Error(error::Error);

type Result<T> = std::result::Result<T, Error>;

/// Number of bytes in one GiB; volume sizes in this module are counted in whole GiB.
const GIBIBYTE: i64 = 1 << 30;
/// Default number of block uploads kept in flight at once.
const SNAPSHOT_BLOCK_WORKERS: usize = 64;
/// Linear backoff factor: retry number `n` sleeps `n * SNAPSHOT_BLOCK_RETRY_SCALE` seconds.
const SNAPSHOT_BLOCK_RETRY_SCALE: u64 = 2;
/// Total number of attempts (first try plus retries) made for each block.
const SNAPSHOT_BLOCK_ATTEMPTS: u64 = 5;
/// Timeout value passed to `StartSnapshot` (minutes, per the name).
const SNAPSHOT_TIMEOUT_MINUTES: i32 = 10;
/// Checksum algorithm used for every block and for the final snapshot checksum.
const SHA256_ALGORITHM: ChecksumAlgorithm = ChecksumAlgorithm::ChecksumAlgorithmSha256;
/// Aggregation method sent alongside the final snapshot checksum.
const LINEAR_METHOD: ChecksumAggregationMethod =
    ChecksumAggregationMethod::ChecksumAggregationLinear;
/// Counters describing how block uploads went.
///
/// Successful uploads are binned by latency into six buckets
/// (<250ms, 250-500ms, 500ms-1s, 1-2s, 2-5s, >=5s); every failed attempt bumps `errors`.
struct UploadStats {
    buckets: [AtomicU64; 6],
    errors: AtomicU64,
}

impl UploadStats {
    /// Creates a collector with every counter at zero.
    fn new() -> Self {
        Self {
            buckets: std::array::from_fn(|_| AtomicU64::new(0)),
            errors: AtomicU64::new(0),
        }
    }

    /// Records one successful upload that took `elapsed`.
    fn record_success(&self, elapsed: Duration) {
        // Exclusive upper bounds (ms) of every bucket except the open-ended last one.
        const UPPER_BOUNDS_MS: [u128; 5] = [250, 500, 1000, 2000, 5000];
        let millis = elapsed.as_millis();
        let slot = UPPER_BOUNDS_MS
            .iter()
            .position(|&bound| millis < bound)
            .unwrap_or(UPPER_BOUNDS_MS.len());
        self.buckets[slot].fetch_add(1, AtomicOrdering::Relaxed);
    }

    /// Records one failed upload attempt.
    fn record_error(&self) {
        self.errors.fetch_add(1, AtomicOrdering::Relaxed);
    }

    /// Logs the latency histogram and error count at `info` level.
    fn report(&self) {
        let [under_250ms, to_500ms, to_1s, to_2s, to_5s, over_5s] = self
            .buckets
            .each_ref()
            .map(|counter| counter.load(AtomicOrdering::Relaxed));
        let failures = self.errors.load(AtomicOrdering::Relaxed);
        info!(
            "Upload complete: <250ms={} 250-500ms={} 500ms-1s={} 1-2s={} 2-5s={} >5s={} errors={}",
            under_250ms, to_500ms, to_1s, to_2s, to_5s, over_5s, failures
        );
    }
}
/// Controls whether blocks that contain only zero bytes are uploaded.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ZeroBlocks {
    /// Upload every block, including all-zero ones.
    Include,
    /// Skip all-zero blocks; they are not sent and not counted as changed blocks.
    Omit,
}
/// Uploads local files or block devices to new EBS snapshots via the EBS direct APIs.
///
/// Block uploads are spread round-robin across one or more EBS clients ("shards");
/// snapshot start and completion always go through the first client.
pub struct SnapshotUploader {
    ebs_clients: Vec<EbsClient>,
}

impl SnapshotUploader {
    /// Creates an uploader that sends every request through a single EBS client.
    pub fn new(ebs_client: EbsClient) -> Self {
        SnapshotUploader {
            ebs_clients: vec![ebs_client],
        }
    }

    /// Creates an uploader that distributes block uploads across `ebs_clients`.
    ///
    /// Block `i` is uploaded through `ebs_clients[i % ebs_clients.len()]`.
    ///
    /// # Panics
    ///
    /// Panics if `ebs_clients` is empty.
    pub fn with_client_shards(ebs_clients: Vec<EbsClient>) -> Self {
        assert!(!ebs_clients.is_empty(), "need at least one EBS client");
        SnapshotUploader { ebs_clients }
    }

    /// Returns the client shard responsible for `block_index` (round-robin).
    fn client_for_block(&self, block_index: i32) -> &EbsClient {
        // Block indices come from `0..file_blocks` and are never negative; the modulo keeps
        // the index in bounds in any case.
        &self.ebs_clients[block_index as usize % self.ebs_clients.len()]
    }

    /// Uploads the file or block device at `path` to a new EBS snapshot and returns its ID.
    ///
    /// * `volume_size` - volume size in GiB; defaults to the smallest size (at least 1 GiB)
    ///   that holds the input.
    /// * `description` - snapshot description; defaults to the file name of `path`.
    /// * `tags` - tags applied when the snapshot is started.
    /// * `progress_bar` - if given, its length is set to the block count and it advances once
    ///   per uploaded or skipped block.
    /// * `zero_blocks` - whether all-zero blocks are uploaded; defaults to
    ///   [`ZeroBlocks::Include`].
    /// * `kms_key_id` - if given, the snapshot is started encrypted with this KMS key.
    /// * `workers` - number of concurrent block uploads; defaults to 64.
    ///
    /// # Errors
    ///
    /// Fails if `workers` is zero, the input cannot be read, `volume_size` is smaller than the
    /// input, or any EBS call fails. Each block is attempted up to 5 times; blocks that still
    /// fail are reported together and the snapshot is not completed.
    #[allow(clippy::too_many_arguments)]
    pub async fn upload_from_file<P: AsRef<Path>>(
        &self,
        path: P,
        volume_size: Option<i64>,
        description: Option<&str>,
        tags: Option<Vec<Tag>>,
        progress_bar: Option<ProgressBar>,
        zero_blocks: Option<ZeroBlocks>,
        kms_key_id: Option<String>,
        workers: Option<usize>,
    ) -> Result<String> {
        let path = path.as_ref();

        // Validate caller input before any side effects: once `start_snapshot` succeeds, an
        // early return would leave a pending snapshot behind.
        let worker_count = workers.unwrap_or(SNAPSHOT_BLOCK_WORKERS);
        ensure!(worker_count > 0, error::InvalidWorkerCountSnafu);

        let description = description.map(|s| s.to_string()).unwrap_or_else(|| {
            path.file_name()
                .unwrap_or_else(|| OsStr::new(""))
                .to_string_lossy()
                .to_string()
        });

        let file_meta = fs::metadata(path)
            .await
            .context(error::ReadFileMetadataSnafu { path })?;
        let file_size = if file_meta.file_type().is_block_device() {
            get_block_device_size(path).context(error::GetBlockDeviceSizeSnafu)?
        } else {
            self.file_size(&file_meta).await?
        };

        // Round up to whole GiB, with a floor of 1 GiB for empty inputs.
        let min_volume_size = cmp::max((file_size + GIBIBYTE - 1) / GIBIBYTE, 1);
        let volume_size = volume_size.unwrap_or(min_volume_size);
        ensure!(
            volume_size >= min_volume_size,
            error::BadVolumeSizeSnafu {
                requested: volume_size,
                needed: min_volume_size,
            }
        );

        debug!("Uploading {volume_size}G to snapshot...");
        let (snapshot_id, block_size) = self
            .start_snapshot(volume_size, description, tags, kms_key_id)
            .await?;

        // `start_snapshot` guarantees `block_size > 0`, so this ceiling division is safe.
        let file_blocks = (file_size + i64::from(block_size - 1)) / i64::from(block_size);
        let file_blocks =
            i32::try_from(file_blocks).with_context(|_| error::ConvertNumberSnafu {
                what: "calculate file blocks",
                number: file_blocks.to_string(),
                target: "i32",
            })?;

        let changed_blocks_count = Arc::new(AtomicI32::new(0));
        let block_digests = Arc::new(Mutex::new(BTreeMap::new()));
        let block_errors = Arc::new(Mutex::new(BTreeMap::new()));

        let progress_bar = match progress_bar {
            Some(pb) => {
                let pb_length = file_blocks;
                let pb_length =
                    u64::try_from(pb_length).with_context(|_| error::ConvertNumberSnafu {
                        what: "progress bar length",
                        number: pb_length.to_string(),
                        target: "u64",
                    })?;
                pb.set_length(pb_length);
                Arc::new(Some(pb))
            }
            None => Arc::new(None),
        };

        let zero_blocks = zero_blocks.unwrap_or(ZeroBlocks::Include);
        let mut block_contexts = Vec::with_capacity(usize::try_from(file_blocks).unwrap_or(0));
        let mut remaining_data = file_size;
        for i in 0..file_blocks {
            // Every block is full-size except possibly the last one.
            let data_length = cmp::min(i64::from(block_size), remaining_data);
            let data_length =
                usize::try_from(data_length).with_context(|_| error::ConvertNumberSnafu {
                    what: "data length",
                    number: data_length.to_string(),
                    target: "usize",
                })?;
            block_contexts.push(BlockContext {
                path: PathBuf::from(path),
                data_length,
                block_index: i,
                block_size,
                snapshot_id: snapshot_id.clone(),
                changed_blocks_count: Arc::clone(&changed_blocks_count),
                block_digests: Arc::clone(&block_digests),
                block_errors: Arc::clone(&block_errors),
                progress_bar: Arc::clone(&progress_bar),
                ebs_client: self.client_for_block(i).clone(),
                zero_blocks,
            });
            remaining_data -= i64::from(block_size);
        }

        debug!(
            "Using {} concurrent upload workers across {} client shards",
            worker_count,
            self.ebs_clients.len()
        );
        let stats = UploadStats::new();
        let stats = &stats;
        let upload = stream::iter(block_contexts).for_each_concurrent(worker_count, |context| {
            async move {
                for attempt in 0..SNAPSHOT_BLOCK_ATTEMPTS {
                    if attempt > 0 {
                        // Linear backoff: 2s, 4s, 6s, ... before each retry.
                        let backoff = Duration::from_secs(attempt * SNAPSHOT_BLOCK_RETRY_SCALE);
                        debug!(
                            "block {}: retry {}/{}, backoff {}s",
                            context.block_index,
                            attempt,
                            SNAPSHOT_BLOCK_ATTEMPTS - 1,
                            backoff.as_secs()
                        );
                        time::sleep(backoff).await;
                    }
                    let start = std::time::Instant::now();
                    let block_result = self.upload_block(&context).await;
                    let elapsed = start.elapsed();
                    match block_result {
                        Ok(()) => {
                            stats.record_success(elapsed);
                            // Clear any error left behind by an earlier failed attempt.
                            context
                                .block_errors
                                .lock()
                                .expect("poisoned")
                                .remove(&context.block_index);
                            break;
                        }
                        Err(e) => {
                            stats.record_error();
                            warn!(
                                "block {}: attempt {}/{} failed after {:.1}s: {}",
                                context.block_index,
                                attempt + 1,
                                SNAPSHOT_BLOCK_ATTEMPTS,
                                elapsed.as_secs_f64(),
                                e
                            );
                            // Keep only the latest failure; it is reported if no attempt
                            // succeeds. The lock is held just for this insert.
                            context
                                .block_errors
                                .lock()
                                .expect("poisoned")
                                .insert(context.block_index, e);
                        }
                    }
                }
            }
        });
        upload.await;
        stats.report();

        // Every `BlockContext` (and its Arc clones) was consumed by the finished stream.
        let block_errors = Arc::try_unwrap(block_errors)
            .expect("referenced")
            .into_inner()
            .expect("poisoned");
        if !block_errors.is_empty() {
            // One line per failed block so the combined report stays readable.
            let error_report = block_errors
                .values()
                .map(ToString::to_string)
                .collect::<Vec<_>>()
                .join("\n");
            error::PutSnapshotBlocksSnafu {
                error_count: block_errors.len(),
                snapshot_id: snapshot_id.clone(),
                error_report,
            }
            .fail()?;
        }

        let changed_blocks_count = changed_blocks_count.load(AtomicOrdering::Relaxed);
        let block_digests = Arc::try_unwrap(block_digests)
            .expect("referenced")
            .into_inner()
            .expect("poisoned");
        // Hash the per-block digests in ascending block-index order (BTreeMap order); the
        // result is sent with the linear aggregation method.
        let mut full_digest = Sha256::new();
        for hash_bytes in block_digests.values() {
            full_digest.update(hash_bytes);
        }
        let full_hash = base64_engine.encode(full_digest.finalize());

        self.complete_snapshot(&snapshot_id, changed_blocks_count, &full_hash)
            .await?;
        Ok(snapshot_id)
    }

    /// Returns the length of a regular file as an `i64`.
    async fn file_size(&self, file_meta: &std::fs::Metadata) -> Result<i64> {
        let file_len = file_meta.len();
        let file_len = i64::try_from(file_len).with_context(|_| error::ConvertNumberSnafu {
            what: "file length",
            number: file_len.to_string(),
            target: "i64",
        })?;
        Ok(file_len)
    }

    /// Starts a snapshot and returns its ID and block size (in bytes).
    ///
    /// A missing or non-positive block size in the response is reported as
    /// `FindSnapshotBlockSize`, because every later block computation divides by it.
    async fn start_snapshot(
        &self,
        volume_size: i64,
        description: String,
        tags: Option<Vec<Tag>>,
        kms_key_id: Option<String>,
    ) -> Result<(String, i32)> {
        let mut request = self.ebs_clients[0]
            .start_snapshot()
            .volume_size(volume_size)
            .set_description(Some(description))
            .set_tags(tags)
            .set_timeout(Some(SNAPSHOT_TIMEOUT_MINUTES));
        if let Some(kms_key_id) = kms_key_id {
            request = request
                .set_encrypted(Some(true))
                .set_kms_key_arn(Some(kms_key_id));
        }
        let start_response = request.send().await.context(error::StartSnapshotSnafu)?;
        let snapshot_id = start_response
            .snapshot_id
            .context(error::FindSnapshotIdSnafu)?;
        let block_size = start_response
            .block_size
            .filter(|&size| size > 0)
            .context(error::FindSnapshotBlockSizeSnafu)?;
        Ok((snapshot_id, block_size))
    }

    /// Completes the snapshot, sending the changed-block count and aggregate SHA-256 checksum.
    async fn complete_snapshot(
        &self,
        snapshot_id: &str,
        changed_blocks_count: i32,
        checksum: &str,
    ) -> Result<()> {
        self.ebs_clients[0]
            .complete_snapshot()
            .snapshot_id(snapshot_id)
            .changed_blocks_count(changed_blocks_count)
            .set_checksum(Some(checksum.to_string()))
            .set_checksum_algorithm(Some(SHA256_ALGORITHM))
            .set_checksum_aggregation_method(Some(LINEAR_METHOD))
            .send()
            .await
            .context(error::CompleteSnapshotSnafu { snapshot_id })?;
        Ok(())
    }

    /// Reads block `context.block_index` from the input and uploads it with `PutSnapshotBlock`.
    ///
    /// A short final block is zero-padded to the full block size. With [`ZeroBlocks::Omit`],
    /// all-zero blocks are skipped but still advance the progress bar. On success, the block's
    /// SHA-256 digest is recorded, the changed-block counter is incremented, and the progress
    /// bar advances.
    async fn upload_block(&self, context: &BlockContext) -> Result<()> {
        let path: &Path = context.path.as_ref();
        let mut f = File::open(path)
            .await
            .context(error::OpenFileSnafu { path })?;

        let block_index_u64: u64 =
            u64::try_from(context.block_index).with_context(|_| error::ConvertNumberSnafu {
                what: "block_index",
                number: context.block_index.to_string(),
                target: "u64",
            })?;
        let block_size_u64: u64 =
            u64::try_from(context.block_size).with_context(|_| error::ConvertNumberSnafu {
                what: "block_size",
                number: context.block_size.to_string(),
                target: "u64",
            })?;
        let offset: u64 = block_index_u64
            .checked_mul(block_size_u64)
            .with_context(|| error::CheckedMultiplicationSnafu {
                right: "block_size",
                right_number: context.block_size.to_string(),
                left: "block_index",
                left_number: context.block_index.to_string(),
                target: "u64",
            })?;
        f.seek(SeekFrom::Start(offset))
            .await
            .context(error::SeekFileOffsetSnafu { path, offset })?;

        let block_size = context.block_size;
        let block_size =
            usize::try_from(block_size).with_context(|_| error::ConvertNumberSnafu {
                what: "block size",
                number: block_size.to_string(),
                target: "usize",
            })?;
        let mut block = BytesMut::with_capacity(block_size);
        let count = context.data_length;
        block.resize(count, 0x0);
        f.read_exact(block.as_mut())
            .await
            .context(error::ReadFileBytesSnafu {
                path,
                count,
                offset,
            })?;

        if matches!(context.zero_blocks, ZeroBlocks::Omit) && block.iter().all(|&byte| byte == 0)
        {
            if let Some(ref progress_bar) = *context.progress_bar {
                progress_bar.inc(1);
            }
            return Ok(());
        }

        // Pad a short final block with zeros so every uploaded block is `block_size` bytes.
        if block.len() < block_size {
            block.resize(block_size, 0x0);
        }

        let mut block_digest = Sha256::new();
        block_digest.update(&block);
        let hash_bytes = block_digest.finalize();
        let block_hash = base64_engine.encode(&hash_bytes);

        let snapshot_id = &context.snapshot_id;
        let block_index = context.block_index;
        let data_length = block.len();
        let data_length =
            i32::try_from(data_length).with_context(|_| error::ConvertNumberSnafu {
                what: "data length",
                number: data_length.to_string(),
                target: "i32",
            })?;
        context
            .ebs_client
            .put_snapshot_block()
            .snapshot_id(snapshot_id.to_string())
            .block_index(block_index)
            .block_data(ByteStream::from(block.freeze()))
            .data_length(data_length)
            .checksum(block_hash)
            .checksum_algorithm(SHA256_ALGORITHM)
            .send()
            .await
            .context(error::PutSnapshotBlockSnafu {
                snapshot_id,
                block_index,
            })?;

        context
            .block_digests
            .lock()
            .expect("poisoned")
            .insert(block_index, hash_bytes.to_vec());
        context
            .changed_blocks_count
            .fetch_add(1, AtomicOrdering::Relaxed);
        if let Some(ref progress_bar) = *context.progress_bar {
            progress_bar.inc(1);
        }
        Ok(())
    }
}
/// Everything one concurrent worker needs to upload a single snapshot block.
struct BlockContext {
    // Which block to read, and from where.
    /// Input file or block device the block is read from.
    path: PathBuf,
    /// Zero-based index of the block within the snapshot.
    block_index: i32,
    /// Snapshot block size in bytes, as returned by `StartSnapshot`.
    block_size: i32,
    /// Number of input bytes in this block (less than `block_size` only for a short tail).
    data_length: usize,
    /// Whether an all-zero block is uploaded or skipped.
    zero_blocks: ZeroBlocks,

    // Where the block goes.
    /// ID of the snapshot being written.
    snapshot_id: String,
    /// Client shard used for this block's `PutSnapshotBlock` call.
    ebs_client: EbsClient,

    // State shared with every other block of the same upload.
    /// Number of blocks actually uploaded.
    changed_blocks_count: Arc<AtomicI32>,
    /// Base64-decoded SHA-256 digest of each uploaded block, keyed by block index.
    block_digests: Arc<Mutex<BTreeMap<i32, Vec<u8>>>>,
    /// Latest failure for each block whose upload has not (yet) succeeded.
    block_errors: Arc<Mutex<BTreeMap<i32, Error>>>,
    /// Optional progress bar advanced once per finished block.
    progress_bar: Arc<Option<ProgressBar>>,
}
mod error {
    use aws_sdk_ebs::operation::{
        complete_snapshot::CompleteSnapshotError, put_snapshot_block::PutSnapshotBlockError,
        start_snapshot::StartSnapshotError,
    };
    use snafu::Snafu;
    use std::path::PathBuf;

    /// Detailed failures behind the module's opaque `Error`.
    ///
    /// SDK errors are rendered through `crate::error_stack` so that the underlying cause is
    /// included; the SDK error's own `Display` alone does not carry it.
    #[derive(Debug, Snafu)]
    #[snafu(visibility(pub(super)))]
    pub(super) enum Error {
        #[snafu(display("Failed to read metadata for '{}': {}", path.display(), source))]
        ReadFileMetadata {
            path: PathBuf,
            source: std::io::Error,
        },

        #[snafu(display("{}", source))]
        GetBlockDeviceSize { source: crate::block_device::Error },

        #[snafu(display(
            "Bad volume size: requested {} GiB, needed at least {} GiB",
            requested,
            needed
        ))]
        BadVolumeSize { requested: i64, needed: i64 },

        #[snafu(display("Failed to open '{}': {}", path.display(), source))]
        OpenFile {
            path: PathBuf,
            source: std::io::Error,
        },

        #[snafu(display("Failed to seek to {} in '{}': {}", offset, path.display(), source))]
        SeekFileOffset {
            path: PathBuf,
            offset: u64,
            source: std::io::Error,
        },

        #[snafu(display(
            "Failed to read {} bytes at offset {} from '{}': {}",
            count,
            offset,
            path.display(),
            source
        ))]
        ReadFileBytes {
            path: PathBuf,
            count: usize,
            offset: u64,
            source: std::io::Error,
        },

        #[snafu(display("Failed to start snapshot: {source}", source = crate::error_stack(&source, 2)))]
        StartSnapshot {
            #[snafu(source(from(aws_sdk_ebs::error::SdkError<StartSnapshotError>, Box::new)))]
            source: Box<aws_sdk_ebs::error::SdkError<StartSnapshotError>>,
        },

        #[snafu(display(
            "Failed to put block {} for snapshot '{}': {}",
            block_index,
            snapshot_id,
            crate::error_stack(&source, 2)
        ))]
        PutSnapshotBlock {
            snapshot_id: String,
            block_index: i64,
            #[snafu(source(from(aws_sdk_ebs::error::SdkError<PutSnapshotBlockError>, Box::new)))]
            source: Box<aws_sdk_ebs::error::SdkError<PutSnapshotBlockError>>,
        },

        #[snafu(display(
            "Failed to put {} blocks for snapshot '{}': {}",
            error_count,
            snapshot_id,
            error_report
        ))]
        PutSnapshotBlocks {
            error_count: usize,
            snapshot_id: String,
            error_report: String,
        },

        #[snafu(display(
            "Failed to complete snapshot '{}': {}",
            snapshot_id,
            crate::error_stack(&source, 2)
        ))]
        CompleteSnapshot {
            snapshot_id: String,
            #[snafu(source(from(aws_sdk_ebs::error::SdkError<CompleteSnapshotError>, Box::new)))]
            source: Box<aws_sdk_ebs::error::SdkError<CompleteSnapshotError>>,
        },

        #[snafu(display("Failed to find snapshot ID"))]
        FindSnapshotId {},

        #[snafu(display("Failed to find snapshot block size"))]
        FindSnapshotBlockSize {},

        #[snafu(display("Failed to convert {} {} to {}: {}", what, number, target, source))]
        ConvertNumber {
            what: String,
            number: String,
            target: String,
            source: std::num::TryFromIntError,
        },

        #[snafu(display("Worker count must be greater than zero"))]
        InvalidWorkerCount,

        #[snafu(display(
            "Overflowed multiplying {} ({}) and {} ({}) inside a {}",
            left,
            left_number,
            right,
            right_number,
            target
        ))]
        CheckedMultiplication {
            left: String,
            left_number: String,
            right: String,
            right_number: String,
            target: String,
        },
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Snapshot of every latency bucket counter, in bucket order.
    fn bucket_counts(stats: &UploadStats) -> Vec<u64> {
        stats
            .buckets
            .iter()
            .map(|counter| counter.load(AtomicOrdering::Relaxed))
            .collect()
    }

    #[test]
    fn histogram_bucket_boundaries() {
        let stats = UploadStats::new();
        // Two samples per bucket: its inclusive lower edge and the last millisecond before the
        // next bucket (a far-out value for the open-ended final bucket).
        let samples_ms = [0, 249, 250, 499, 500, 999, 1000, 1999, 2000, 4999, 5000, 60000];
        for ms in samples_ms {
            stats.record_success(Duration::from_millis(ms));
        }
        let counts = bucket_counts(&stats);
        for bucket in 0..6 {
            assert_eq!(counts[bucket], 2, "bucket {bucket}");
        }
    }

    #[test]
    fn error_counter() {
        let stats = UploadStats::new();
        (0..3).for_each(|_| stats.record_error());
        assert_eq!(stats.errors.load(AtomicOrdering::Relaxed), 3);
    }

    // NOTE(review): this re-derives the round-robin arithmetic used by
    // `SnapshotUploader::client_for_block` instead of calling it (that would need constructed
    // `EbsClient` values), so it cannot catch a regression in the method itself.
    #[test]
    fn client_for_block_modulo_logic() {
        const SHARDS: usize = 3;
        let expected: [usize; 9] = [0, 1, 2, 0, 1, 2, 0, 1, 2];
        assert!(expected
            .iter()
            .enumerate()
            .all(|(block, &shard)| block % SHARDS == shard));
    }

    #[test]
    #[should_panic(expected = "need at least one EBS client")]
    fn with_client_shards_rejects_empty() {
        let _ = SnapshotUploader::with_client_shards(Vec::new());
    }
}