use crate::destination::{Destination, resolve_destination};
use crate::protocol::{
DestinationProtocol, DestinationWriter, SourceProtocol, SourceReader, TransferError,
};
use crate::source::{Source, resolve_source};
use futures_util::StreamExt;
use indicatif::{HumanBytes, HumanDuration, ProgressStyle};
use std::pin::pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tracing::{Instrument, Span};
use tracing_indicatif::span_ext::IndicatifSpanExt;
use url::Url;
/// Smallest backoff delay (the basis for the first retry attempt).
const BACKOFF_BASE: Duration = Duration::from_secs(1);
/// Upper bound on the exponential portion of the backoff. A larger
/// server-supplied hint is still honored (see [`backoff_delay`]).
const BACKOFF_MAX: Duration = Duration::from_secs(60);

/// Computes the delay to sleep before retry number `attempt` (0-based).
///
/// The delay grows as `BACKOFF_BASE * 2^attempt`, is multiplied by a random
/// jitter factor in `[0.75, 1.25]` to de-synchronize concurrent clients, and
/// is capped at `BACKOFF_MAX`. The server's minimum-retry-delay hint acts as
/// a floor and is honored even beyond the cap.
pub(crate) fn backoff_delay(attempt: u32, server_hint: Duration) -> Duration {
    // 2^attempt, saturating once the shift would overflow (attempt >= 32);
    // saturating_mul then keeps the product well inside Duration's range.
    let factor = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
    let exp = BACKOFF_BASE.saturating_mul(factor);
    let jitter_frac = rand::random_range(0.75..=1.25);
    // `mul_f64` scales the Duration directly, avoiding the lossy
    // secs -> f64 -> Duration round-trip of `from_secs_f64(as_secs_f64() * x)`,
    // which discards sub-second precision at large magnitudes.
    let jittered_exp = exp.mul_f64(jitter_frac);
    jittered_exp.min(BACKOFF_MAX).max(server_hint)
}
/// User-facing knobs controlling a single transfer.
#[derive(Debug, Clone)]
pub struct TransferConfig {
    /// Maximum number of transient-error retries before giving up.
    pub max_retries: u32,
    /// Whether an existing destination may be overwritten.
    pub overwrite: bool,
    /// Extra `(name, value)` headers attached to HTTP requests.
    pub custom_http_headers: Vec<(String, String)>,
}
/// Thread-safe, shareable transfer-progress counters.
///
/// Counters are plain atomics read and written with `Ordering::Relaxed`
/// (see the accessor impl): snapshots are suitable for display, not for
/// synchronization.
#[derive(Debug)]
pub struct ProgressState {
    // Absolute number of bytes confirmed written so far.
    bytes_written: AtomicU64,
    // Expected total size in bytes; 0 means unknown.
    total_size: AtomicU64,
    // Construction time; basis for `elapsed()`.
    start_time: Instant,
}
impl Default for ProgressState {
    /// Equivalent to [`ProgressState::new`]: zeroed counters with the clock
    /// started at construction time.
    fn default() -> Self {
        Self::new()
    }
}
impl ProgressState {
    /// Creates a fresh tracker with zeroed counters and the clock started
    /// at the moment of construction.
    #[must_use]
    pub fn new() -> Self {
        Self {
            bytes_written: AtomicU64::default(),
            total_size: AtomicU64::default(),
            start_time: Instant::now(),
        }
    }

    /// Total bytes confirmed written so far.
    pub fn bytes_written(&self) -> u64 {
        self.bytes_written.load(Ordering::Relaxed)
    }

    /// Expected total transfer size; 0 when unknown.
    pub fn total_size(&self) -> u64 {
        self.total_size.load(Ordering::Relaxed)
    }

    /// Wall-clock time since this tracker was created.
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Records the absolute byte count written (not a delta).
    fn update_bytes_written(&self, bytes: u64) {
        self.bytes_written.store(bytes, Ordering::Relaxed);
    }

    /// Records the expected total size of the transfer.
    fn update_total_size(&self, total: u64) {
        self.total_size.store(total, Ordering::Relaxed);
    }
}
/// Renders a one-line, human-readable progress summary.
///
/// When `total_size` is known (> 0) the line includes the percentage, byte
/// counts, speed, and an ETA (the ETA is omitted while the measured speed
/// is zero). With an unknown total, only the transferred amount and speed
/// are shown.
#[must_use]
#[expect(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    reason = "acceptable for human-readable progress display"
)]
pub fn format_progress_log(bytes_written: u64, total_size: u64, elapsed: Duration) -> String {
    // Average speed in bytes/sec; zero elapsed time yields 0.0 to avoid a
    // division by zero (which also suppresses the ETA below).
    let speed = if elapsed.as_secs_f64() > 0.0 {
        bytes_written as f64 / elapsed.as_secs_f64()
    } else {
        0.0
    };
    let speed_display = HumanBytes(speed as u64);
    if total_size > 0 {
        let pct = (bytes_written as f64 / total_size as f64) * 100.0;
        let eta = if speed > 0.0 {
            let remaining = total_size.saturating_sub(bytes_written);
            let eta_secs = remaining as f64 / speed;
            // `try_from_secs_f64` instead of `from_secs_f64`: at very low
            // measured speeds the projected ETA can exceed `Duration`'s
            // representable range, and `from_secs_f64` would panic.
            // Saturate to `Duration::MAX` instead.
            let eta_dur = Duration::try_from_secs_f64(eta_secs).unwrap_or(Duration::MAX);
            format!(", ETA {}", HumanDuration(eta_dur))
        } else {
            String::new()
        };
        format!(
            "Progress: {pct:.1}% ({} / {}) at {speed_display}/s{eta}",
            HumanBytes(bytes_written),
            HumanBytes(total_size),
        )
    } else {
        format!(
            "Progress: {} transferred at {speed_display}/s",
            HumanBytes(bytes_written),
        )
    }
}
#[doc(hidden)]
#[macro_export]
/// Awaits `$op` repeatedly, retrying on `TransferError::Transient` with
/// exponential backoff, allowing up to `$max_retries` retries. Breaks with
/// `Ok(val)` on success, the original error for non-transient failures, or
/// a `Permanent` "exhausted retries" error once the budget is spent.
///
/// `$op` is pasted textually into the loop body, so it must be an expression
/// that produces a fresh future each time it is evaluated (e.g. a method
/// call) — it is re-awaited on every attempt.
macro_rules! retry_transient {
    ($max_retries:expr, $op:expr) => {{
        let mut n_retries: u32 = 0;
        loop {
            match $op.await {
                Ok(val) => break Ok(val),
                // Transient failure: back off (honoring the server's minimum
                // retry-delay hint) and try again, unless the budget is spent.
                Err($crate::protocol::TransferError::Transient {
                    minimum_retry_delay: server_hint,
                    reason,
                    ..
                }) => {
                    n_retries += 1;
                    if n_retries > $max_retries {
                        break Err($crate::protocol::TransferError::Permanent {
                            reason: format!(
                                "exhausted {} retries (last error: {reason})",
                                $max_retries
                            ),
                        });
                    }
                    // The backoff curve uses a 0-based attempt index.
                    let delay = $crate::transfer::backoff_delay(n_retries - 1, server_hint);
                    tracing::warn!(
                        "Transient error on attempt {}/{}: {reason}. Retrying after {delay:?}.",
                        n_retries,
                        $max_retries
                    );
                    tokio::time::sleep(delay).await;
                }
                // Permanent (or any other) error: give up immediately.
                Err(e) => break Err(e),
            }
        }
    }};
}
/// Resolves `source_url` and `dest_url` into concrete protocol handlers and
/// runs the transfer between them inside an instrumented tracing span,
/// mirroring progress into `progress` when provided.
///
/// # Errors
/// Returns any source/destination resolution error, or the transfer's final
/// error, with transient writer-setup failures retried per
/// `config.max_retries`.
pub async fn execute_transfer(
    source_url: Url,
    dest_url: Url,
    config: &TransferConfig,
    progress: Option<Arc<ProgressState>>,
) -> Result<u64, TransferError> {
    match (
        resolve_source(&source_url, config)?,
        resolve_destination(&dest_url, config)?,
    ) {
        (Source::Http(mut src), Destination::File(dest)) => {
            // Use the configured retry budget for writer setup too, instead
            // of a hard-coded count, so `config.max_retries` governs the
            // whole transfer uniformly.
            let writer =
                retry_transient!(config.max_retries, dest.get_writer(dest_url.clone()))?;
            run_transfer(&mut src, writer, source_url, config, progress)
                .instrument(tracing::info_span!(
                    "transfer",
                    indicatif.pb_show = tracing::field::Empty
                ))
                .await
        }
    }
}
#[expect(
    clippy::missing_panics_doc,
    reason = "unwrap on a hardcoded infallible template"
)]
#[expect(clippy::too_many_lines, reason = "core transfer orchestration loop")]
/// Core transfer loop: streams bytes from `source` into `writer`, resuming
/// from the last confirmed offset after transient failures, until the stream
/// completes or the shared retry budget (`config.max_retries`) is spent.
///
/// Progress is mirrored into the current tracing span's progress bar and,
/// when provided, into the shared `progress` state.
///
/// # Errors
/// Returns `TransferError::Permanent` on a permanent source/writer error, on
/// retry exhaustion (the last transient reason is embedded in the message),
/// or when the source streams from an unrecoverable offset.
pub async fn run_transfer<S: SourceProtocol, W: DestinationWriter>(
    source: &mut S,
    mut writer: W,
    source_url: Url,
    config: &TransferConfig,
    progress: Option<Arc<ProgressState>>,
) -> Result<u64, TransferError> {
    // One retry budget shared by reader restarts and mid-stream failures.
    let mut retry_count: u32 = 0;
    // Bytes confirmed written; doubles as the resume offset requested from
    // the source whenever the stream has to be reopened.
    let mut total_bytes_written: u64 = 0;
    loop {
        // (Re)open the source at the resume offset. Reader acquisition has
        // its own transient-retry wrapper using the configured budget.
        let (reader, read_start) = retry_transient!(
            config.max_retries,
            source.get_reader(source_url.clone(), total_bytes_written)
        )?;
        let span = Span::current();
        if let Some(total) = read_start.total_size {
            // Known total: determinate progress bar.
            span.pb_set_style(
                &ProgressStyle::with_template(
                    "{spinner:.green} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})",
                )
                .unwrap()
                .progress_chars("#>-"),
            );
            span.pb_set_length(total);
            span.pb_set_position(total_bytes_written);
            if let Some(ref ps) = progress {
                ps.update_total_size(total);
                ps.update_bytes_written(total_bytes_written);
            }
        } else {
            // Unknown total: spinner with a running byte count only.
            span.pb_set_style(
                &ProgressStyle::with_template("{spinner:.green} {bytes} ({bytes_per_sec})")
                    .unwrap(),
            );
        }
        if read_start.offset != total_bytes_written {
            // The source did not honor our requested resume offset.
            if read_start.offset != 0 && retry_count > config.max_retries {
                return Err(TransferError::Permanent {
                    reason: format!(
                        "Source streaming from offset {} but we requested to start from {total_bytes_written}. This is more degenerate than simply ignoring a streaming offset, so we can't reasonably recover.",
                        read_start.offset
                    ),
                });
            }
            // NOTE(review): the log below assumes the source restarted from
            // offset 0, but this path is also reached when
            // `read_start.offset` is nonzero and the retry budget is not yet
            // exhausted; in that case the incoming stream begins at
            // `read_start.offset` while the writer is reset to 0 — confirm
            // that is intended.
            tracing::info!(
                "Source streaming from the start (we requested offset {total_bytes_written}). This source probably does not support range transfers. Restarting the transfer from the start."
            );
            retry_transient!(3, writer.truncate_and_reset())?;
            total_bytes_written = 0;
            Span::current().pb_set_position(0);
            if let Some(ref ps) = progress {
                ps.update_bytes_written(0);
            }
        }
        let mut stream = pin!(reader.stream_bytes());
        // Set when a transient failure aborts this stream, so the outer loop
        // reopens the source; a clean end of stream leaves it false.
        let mut stream_failed = false;
        while let Some(result) = stream.next().await {
            match result {
                Ok(bytes) => match writer.write(&bytes).await {
                    Ok(()) => {
                        total_bytes_written += bytes.len() as u64;
                        Span::current().pb_set_position(total_bytes_written);
                        if let Some(ref ps) = progress {
                            ps.update_bytes_written(total_bytes_written);
                        }
                    }
                    Err(TransferError::Transient {
                        consumed_byte_count,
                        minimum_retry_delay: server_hint,
                        reason,
                    }) => {
                        // The writer reports how many bytes it actually
                        // persisted; rewind our notion of progress to that.
                        total_bytes_written = consumed_byte_count;
                        Span::current().pb_set_position(total_bytes_written);
                        if let Some(ref ps) = progress {
                            ps.update_bytes_written(total_bytes_written);
                        }
                        retry_count += 1;
                        if retry_count > config.max_retries {
                            return Err(TransferError::Permanent {
                                reason: format!(
                                    "exhausted {} retries (last error: {reason})",
                                    config.max_retries
                                ),
                            });
                        }
                        let delay = backoff_delay(retry_count - 1, server_hint);
                        tracing::warn!(
                            "Transient write error after {consumed_byte_count} bytes: {reason}. Will resume after {delay:?}."
                        );
                        tokio::time::sleep(delay).await;
                        stream_failed = true;
                        break;
                    }
                    Err(e @ TransferError::Permanent { .. }) => return Err(e),
                },
                Err(TransferError::Transient {
                    consumed_byte_count: _,
                    minimum_retry_delay: server_hint,
                    reason,
                }) => {
                    // Read-side transient failure: everything written so far
                    // remains valid, so just back off and reopen the source
                    // at the current offset.
                    retry_count += 1;
                    if retry_count > config.max_retries {
                        return Err(TransferError::Permanent {
                            reason: format!(
                                "exhausted {} retries (last error: {reason})",
                                config.max_retries
                            ),
                        });
                    }
                    let delay = backoff_delay(retry_count - 1, server_hint);
                    tracing::warn!(
                        "Transient error during streaming on attempt {retry_count}/{}: {reason}. \
                        Retrying after {delay:?}.",
                        config.max_retries
                    );
                    tokio::time::sleep(delay).await;
                    stream_failed = true;
                    break;
                }
                Err(e @ TransferError::Permanent { .. }) => return Err(e),
            }
        }
        if !stream_failed {
            // Stream ended cleanly — the transfer is complete.
            break;
        }
    }
    // Finalize the destination; errors here propagate directly (no retry).
    writer.finalize().await?;
    tracing::info!("Transfer complete: {total_bytes_written} bytes written.");
    Ok(total_bytes_written)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With no server hint, successive attempts must yield strictly larger
    /// delays despite jitter (the 0.75..=1.25 bands around 1s, 2s, 4s, 8s
    /// never overlap).
    #[test]
    fn backoff_increases_exponentially() {
        let delays: Vec<Duration> = (0..4).map(|a| backoff_delay(a, Duration::ZERO)).collect();
        assert!(
            delays[0] >= Duration::from_millis(750),
            "d0={:?}",
            delays[0]
        );
        for pair in delays.windows(2) {
            assert!(
                pair[1] > pair[0],
                "expected {:?} > {:?}",
                pair[1],
                pair[0]
            );
        }
    }

    /// The server's minimum-delay hint is a hard floor on the result.
    #[test]
    fn backoff_respects_server_hint_as_floor() {
        let hint = Duration::from_secs(30);
        let delay = backoff_delay(0, hint);
        assert!(
            delay >= hint,
            "delay {delay:?} must be >= server hint {hint:?}"
        );
    }

    /// A large attempt number must not push the delay past the cap.
    #[test]
    fn backoff_caps_at_maximum() {
        let delay = backoff_delay(20, Duration::ZERO);
        assert!(
            delay <= BACKOFF_MAX,
            "delay {delay:?} exceeds maximum {BACKOFF_MAX:?}"
        );
    }

    /// A server hint larger than the cap still wins.
    #[test]
    fn backoff_honors_server_hint_beyond_cap() {
        let hint = Duration::from_secs(300);
        let delay = backoff_delay(0, hint);
        assert!(
            delay >= hint,
            "delay {delay:?} must honor server hint {hint:?} even beyond cap"
        );
    }

    /// Known total: percentage, units, speed, and ETA all appear.
    #[test]
    fn progress_log_with_known_total() {
        let msg = format_progress_log(500_000_000, 1_000_000_000, Duration::from_secs(50));
        for needle in ["50.0%", "MiB", "/s", "ETA"] {
            assert!(msg.contains(needle), "expected {needle:?} in: {msg}");
        }
    }

    /// Unknown total: no percentage or ETA, just the transferred amount.
    #[test]
    fn progress_log_unknown_total() {
        let msg = format_progress_log(500_000_000, 0, Duration::from_secs(50));
        assert!(
            msg.contains("transferred"),
            "expected 'transferred', got: {msg}"
        );
        assert!(!msg.contains('%'), "should not have percentage, got: {msg}");
        assert!(!msg.contains("ETA"), "should not have ETA, got: {msg}");
    }

    /// Zero elapsed time must not divide by zero; shows 0%.
    #[test]
    fn progress_log_zero_elapsed() {
        let msg = format_progress_log(0, 1000, Duration::ZERO);
        assert!(msg.contains("0.0%"), "expected 0%, got: {msg}");
    }

    /// A finished transfer reports exactly 100%.
    #[test]
    fn progress_log_complete() {
        let msg = format_progress_log(1000, 1000, Duration::from_secs(10));
        assert!(msg.contains("100.0%"), "expected 100%, got: {msg}");
    }
}