use std::{
collections::{HashMap, VecDeque},
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
};
use tokio_util::sync::CancellationToken;
use bytes::Bytes;
use reqwest::{
Client,
header::{ACCEPT_ENCODING, HeaderValue, RANGE, USER_AGENT},
};
use tokio::{
fs,
sync::{Mutex, Notify, mpsc},
task::JoinSet,
time::{self, Duration, Instant},
};
use tracing::{Instrument, info_span};
use ulid::Ulid;
use prost::Message;
use crate::progress::{
DownloadContext, ProgressEvent, ProgressTracker, SAMPLE_INTERVAL, speed_window_rate,
trim_speed_window,
};
use crate::retry_policies::{FixedThenExponentialRetry, wait_for_retry};
use crate::{
download::Download,
download_manager::io::persist_encoded_metadata,
download_metadata::{DownloadMetadata, PartDetails},
error::{MetadataError, OdlError},
user_agents::random_user_agent,
};
// Width of the sliding window the speed sampler keeps for aggregate and per-part samples.
const SPEED_WINDOW: Duration = Duration::from_millis(1500);
// Minimum size handed to `Download::compute_split`; a part is only a split candidate
// while at least twice this many bytes remain.
const MIN_DYNAMIC_SPLIT_SIZE: u64 = 3 * 1024 * 1024; #[cfg(not(test))]
// Dynamic splitting is skipped while the tracker's ETA is at or below this value.
const MIN_DYNAMIC_SPLIT_ETA: Duration = Duration::from_secs(60);
// Test builds lower the ETA threshold to zero.
#[cfg(test)]
const MIN_DYNAMIC_SPLIT_ETA: Duration = Duration::from_secs(0);
// Dynamic splitting is skipped until the tracker's elapsed time exceeds this value.
#[cfg(not(test))]
const MIN_DYNAMIC_SPLIT_ELAPSED: Duration = Duration::from_secs(15);
// Test builds lower the elapsed threshold to zero.
#[cfg(test)]
const MIN_DYNAMIC_SPLIT_ELAPSED: Duration = Duration::from_millis(0);
// Timeout applied to sending a request and to each wait for the next body chunk.
#[cfg(not(test))]
const STALE_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
// Test builds use a shorter stall timeout.
#[cfg(test)]
const STALE_CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
/// Settings for opening connections in batches ("ramp-up") rather than all at once.
#[derive(Debug, Clone, Copy)]
pub struct RampupConfig {
/// When `false`, `fill_capacity` schedules parts straight up to the live connection cap.
pub enabled: bool,
/// Number of parts opened per batch while ramp-up is armed; values below 1 are treated as 1.
pub batch_size: u64,
/// Lower bound of the randomized pause between two batches.
pub delay_min: Duration,
/// Upper bound of the randomized pause between two batches.
pub delay_max: Duration,
}
/// Picks a random pause in the inclusive range `[min, max]`.
///
/// Returns `min` unchanged when `max <= min`. Bounds whose nanosecond count does
/// not fit in a `u64` are clamped to `u64::MAX` nanoseconds before sampling.
fn sample_rampup_delay(min: Duration, max: Duration) -> Duration {
    use rand::RngExt;
    if max <= min {
        return min;
    }
    // Clamp each bound into the u64 nanosecond domain used for sampling.
    let to_nanos = |d: Duration| u64::try_from(d.as_nanos()).unwrap_or(u64::MAX);
    let sampled = rand::rng().random_range(to_nanos(min)..=to_nanos(max));
    Duration::from_nanos(sampled)
}
impl RampupConfig {
    /// Test-only configuration with ramp-up switched off and no inter-batch delay.
    #[cfg(test)]
    pub fn disabled() -> Self {
        let no_delay = Duration::ZERO;
        Self {
            enabled: false,
            batch_size: 1,
            delay_min: no_delay,
            delay_max: no_delay,
        }
    }
}
/// Drives one download: schedules part tasks, tracks progress, and persists metadata.
pub struct Downloader {
// Download description; supplies the URL plus part-file and metadata paths.
instruction: Arc<Download>,
// Mutable part table, updated on splits and on part completion.
metadata: Arc<Mutex<DownloadMetadata>>,
// HTTP client shared by every part task.
client: Arc<Client>,
// When set, each request carries a randomly chosen User-Agent header.
randomize_user_agent: bool,
// Enables splitting of active parts to create more pending work.
dynamic_split: bool,
// Batch/delay settings used when opening connections.
rampup: RampupConfig,
// Optional limiter shared by all parts; `None` when no positive limit was given.
speed_limiter: Option<Arc<BandwidthLimiter>>,
// Retry policy handed to every part task.
retry_policy: FixedThenExponentialRetry,
// Serializes writes of the encoded metadata file.
persist_mutex: Arc<Mutex<()>>,
// Cancellation, live controls, and the progress event sink.
ctx: DownloadContext,
// Aggregate byte counter for the whole download.
tracker: Arc<ProgressTracker>,
// Controllers of currently running parts, read by the speed sampler task.
active_parts: Arc<std::sync::Mutex<HashMap<String, Arc<PartController>>>>,
// While true, ramp-up uses the configured batch size; once cleared it opens one part at a time.
ramp_armed: std::sync::atomic::AtomicBool,
}
impl Downloader {
/// Builds a downloader for `instruction`, starting from the given `metadata`.
///
/// Seeds the live connection cap from `metadata.max_connections` (at least 1)
/// if it has not been set yet, creates a bandwidth limiter only for a positive
/// `speed_limit`, and pre-advances the progress tracker by the size of every
/// part already marked finished.
#[allow(clippy::too_many_arguments)]
pub fn new(
    instruction: Arc<Download>,
    metadata: DownloadMetadata,
    client: Client,
    randomize_user_agent: bool,
    speed_limit: Option<u64>,
    dynamic_split: bool,
    rampup: RampupConfig,
    retry_policy: FixedThenExponentialRetry,
    ctx: DownloadContext,
) -> Self {
    let wanted_connections = (metadata.max_connections as usize).max(1);
    ctx.live.seed_if_unset(wanted_connections);

    let speed_limiter = match speed_limit {
        Some(bytes_per_second) if bytes_per_second > 0 => {
            Some(Arc::new(BandwidthLimiter::new(bytes_per_second)))
        }
        _ => None,
    };

    let tracker = Arc::new(ProgressTracker::new(metadata.size));
    let finished_bytes: u64 = metadata
        .parts
        .values()
        .filter_map(|part| part.finished.then_some(part.size))
        .sum();
    if finished_bytes > 0 {
        tracker.advance(finished_bytes);
    }

    Self {
        instruction,
        metadata: Arc::new(Mutex::new(metadata)),
        client: Arc::new(client),
        randomize_user_agent,
        dynamic_split,
        rampup,
        speed_limiter,
        retry_policy,
        persist_mutex: Arc::new(Mutex::new(())),
        ctx,
        tracker,
        active_parts: Arc::new(std::sync::Mutex::new(HashMap::new())),
        ramp_armed: std::sync::atomic::AtomicBool::new(true),
    }
}
/// Runs the download to completion and returns the final metadata.
///
/// Accounts for bytes already on disk, starts the speed sampler, drives the
/// scheduling loop, and aborts the sampler before returning the loop's result.
pub async fn run(self) -> Result<DownloadMetadata, OdlError> {
    self.seed_tracker_with_unfinished_parts().await;
    let sampler = self.spawn_speed_sampler();
    let outcome = self.run_inner().await;
    sampler.abort();
    outcome
}
/// Credits the tracker with bytes already present in unfinished part files.
///
/// Each part contributes at most its declared size; parts whose file cannot be
/// inspected are ignored. A `Progress` event is emitted only when something was
/// credited.
async fn seed_tracker_with_unfinished_parts(&self) {
    // Snapshot the unfinished parts so the metadata lock is released before any file access.
    let unfinished: Vec<PartDetails> = {
        let guard = self.metadata.lock().await;
        guard
            .parts
            .values()
            .filter(|part| !part.finished)
            .cloned()
            .collect()
    };

    let mut resumed_bytes: u64 = 0;
    for part in &unfinished {
        let Ok(on_disk) = self.detect_existing_size(part).await else {
            continue;
        };
        resumed_bytes = resumed_bytes.saturating_add(on_disk.min(part.size));
    }

    if resumed_bytes == 0 {
        return;
    }
    self.tracker.advance(resumed_bytes);
    self.ctx.emit(ProgressEvent::Progress {
        downloaded: self.tracker.downloaded(),
        total: self.tracker.total(),
    });
}
/// Spawns a background task that periodically emits progress and speed events.
///
/// Every `SAMPLE_INTERVAL` the task samples the aggregate byte counter and each
/// active part's controller, keeps the samples in sliding windows of
/// `SPEED_WINDOW`, and emits `Speed`/`Progress` plus `PartSpeed`/`PartProgress`
/// events. The task returns when the context's cancel token fires; the caller
/// may also abort it through the returned handle.
fn spawn_speed_sampler(&self) -> tokio::task::JoinHandle<()> {
let tracker = Arc::clone(&self.tracker);
let ctx = self.ctx.clone();
let active = Arc::clone(&self.active_parts);
tokio::spawn(async move {
let mut agg_window: VecDeque<(std::time::Instant, u64)> = VecDeque::new();
let mut part_windows: HashMap<String, VecDeque<(std::time::Instant, u64)>> =
HashMap::new();
// Seed the aggregate window so the first tick already has a baseline sample.
agg_window.push_back((std::time::Instant::now(), tracker.downloaded()));
loop {
tokio::select! {
_ = ctx.cancel.cancelled() => return,
_ = time::sleep(SAMPLE_INTERVAL) => {}
}
let now = std::time::Instant::now();
let cur = tracker.downloaded();
agg_window.push_back((now, cur));
trim_speed_window(&mut agg_window, now, SPEED_WINDOW);
// A speed event is only emitted when the window yields a rate.
if let Some(bps) = speed_window_rate(&agg_window) {
ctx.emit(ProgressEvent::Speed {
bytes_per_second: bps,
});
}
ctx.emit(ProgressEvent::Progress {
downloaded: cur,
total: tracker.total(),
});
// Copy the controller handles out so the std mutex is released before emitting events.
let snapshot: Vec<(String, Arc<PartController>)> = {
let map = active.lock().unwrap();
map.iter()
.map(|(k, v)| (k.clone(), Arc::clone(v)))
.collect()
};
let mut seen_parts = std::collections::HashSet::with_capacity(snapshot.len());
for (ulid, controller) in snapshot {
let part_cur = controller.downloaded();
let part_lim = controller.limit();
let win = part_windows.entry(ulid.clone()).or_default();
win.push_back((now, part_cur));
trim_speed_window(win, now, SPEED_WINDOW);
if let Some(bps) = speed_window_rate(win) {
ctx.emit(ProgressEvent::PartSpeed {
ulid: ulid.clone(),
bytes_per_second: bps,
});
}
ctx.emit(ProgressEvent::PartProgress {
ulid: ulid.clone(),
downloaded: part_cur,
total: part_lim,
});
seen_parts.insert(ulid);
}
// Drop windows of parts that are no longer active.
part_windows.retain(|k, _| seen_parts.contains(k));
}
})
}
/// Main scheduling loop: starts parts, reacts to their results, and returns the metadata.
///
/// The first pending part is started alone as a probe. Ramp-up stays armed only
/// if that part signals its probe before it produces a join result. Afterwards
/// the loop refills capacity whenever a task finishes or the live controls
/// change, and stops when the join set is empty.
///
/// # Errors
/// Returns `OdlError::Cancelled` when the context is cancelled, propagates
/// errors from scheduling and result handling, and reports a metadata error if
/// the metadata `Arc` is still shared when the loop ends.
async fn run_inner(self) -> Result<DownloadMetadata, OdlError> {
let mut pending = self.pending_parts().await;
let mut active: HashMap<String, ActiveTask> = HashMap::new();
let mut join_set: JoinSet<Result<PartEvent, OdlError>> = JoinSet::new();
if let Some(first_part) = pending.pop_front() {
let probe = Arc::new(Notify::new());
self.schedule_part(first_part, &mut active, &mut join_set, Some(probe.clone()))
.await?;
tokio::select! {
// The first part signalled its probe: keep batch ramp-up armed.
_ = probe.notified() => {
self.ramp_armed.store(true, std::sync::atomic::Ordering::Relaxed);
}
// The first part produced a result before signalling: disarm batch ramp-up.
maybe_res = join_set.join_next() => {
self.ramp_armed.store(false, std::sync::atomic::Ordering::Relaxed);
if let Some(res) = maybe_res {
self.handle_join_result_item(res, &mut pending, &mut active, &mut join_set).await?;
}
}
_ = self.ctx.cancel.cancelled() => {
join_set.shutdown().await;
return Err(OdlError::Cancelled);
}
}
}
self.fill_capacity(&mut pending, &mut active, &mut join_set)
.await?;
loop {
let live_changed = self.ctx.live.notified();
tokio::pin!(live_changed);
tokio::select! {
_ = self.ctx.cancel.cancelled() => {
join_set.shutdown().await;
return Err(OdlError::Cancelled);
}
// Live controls changed: cancel surplus tasks, then try to open more.
_ = &mut live_changed => {
self.apply_live_cap(&mut active);
self.fill_capacity(&mut pending, &mut active, &mut join_set).await?;
}
next = join_set.join_next() => {
// `None` means no task is left in the join set.
let Some(result) = next else { break };
self.handle_join_result_item(result, &mut pending, &mut active, &mut join_set)
.await?;
self.fill_capacity(&mut pending, &mut active, &mut join_set)
.await?;
}
}
}
// Unwrapping only succeeds when no other clone of the metadata Arc is still alive.
let metadata_mutex = Arc::try_unwrap(self.metadata).map_err(|_| {
OdlError::MetadataError(MetadataError::Other {
message: "Failed to unwrap metadata Arc".to_string(),
})
})?;
Ok(metadata_mutex.into_inner())
}
/// Returns copies of all parts not yet marked finished, in the metadata map's iteration order.
async fn pending_parts(&self) -> VecDeque<PartDetails> {
    let guard = self.metadata.lock().await;
    let mut queue = VecDeque::new();
    for part in guard.parts.values() {
        if !part.finished {
            queue.push_back(part.clone());
        }
    }
    queue
}
/// Starts pending parts until the live connection cap is reached.
///
/// Does nothing while the live cap is 0. When dynamic splitting is enabled the
/// pending queue is topped up first. Without ramp-up, parts are scheduled
/// back-to-back. With ramp-up, parts are opened in batches: every part of a
/// batch gets a probe, and the next batch is only opened after all probes of
/// the current batch fired and a randomized delay elapsed. A part failure
/// during a batch disarms batch ramp-up and ends this call.
///
/// Cancellation is reported as `Ok(())`; the caller's loop observes the cancel
/// token itself.
///
/// # Errors
/// Propagates errors from splitting, scheduling, and join-result handling.
// NOTE(review): a part that finishes without streaming a chunk (for example
// because its file is already complete) never fires its probe, so the batch
// wait below only ends through a failure, cancellation, or an empty join set.
// Confirm whether that is intended.
async fn fill_capacity(
    &self,
    pending: &mut VecDeque<PartDetails>,
    active: &mut HashMap<String, ActiveTask>,
    join_set: &mut JoinSet<Result<PartEvent, OdlError>>,
) -> Result<(), OdlError> {
    if self.ctx.live.max_connections() == 0 {
        return Ok(());
    }
    self.ensure_pending_pool(pending, active).await?;

    if !self.rampup.enabled {
        while active.len() < self.ctx.live.max_connections() {
            let Some(part) = pending.pop_front() else {
                return Ok(());
            };
            self.schedule_part(part, active, join_set, None).await?;
        }
        return Ok(());
    }

    // Once ramp-up has been disarmed, connections are opened one at a time.
    let batch_size = if self.ramp_armed.load(std::sync::atomic::Ordering::Relaxed) {
        self.rampup.batch_size.max(1)
    } else {
        1
    };

    loop {
        let cap = self.ctx.live.max_connections();
        if cap == 0 || active.len() >= cap {
            return Ok(());
        }

        let mut probes: Vec<Arc<Notify>> = Vec::new();
        let mut opened_in_batch: u64 = 0;
        while opened_in_batch < batch_size && active.len() < cap {
            let Some(part) = pending.pop_front() else {
                break;
            };
            let probe = Arc::new(Notify::new());
            self.schedule_part(part, active, join_set, Some(probe.clone()))
                .await?;
            probes.push(probe);
            opened_in_batch += 1;
        }
        if probes.is_empty() {
            return Ok(());
        }

        let (mut tx, mut rx) = tokio::sync::oneshot::channel::<()>();
        tokio::spawn(async move {
            let all_started = async {
                for probe in &probes {
                    probe.notified().await;
                }
            };
            // Stop waiting as soon as the receiver is dropped; otherwise this
            // task would stay parked forever on probes that never fire.
            let started = tokio::select! {
                _ = all_started => true,
                _ = tx.closed() => false,
            };
            if started {
                let _ = tx.send(());
            }
        });

        let mut batch_ok = false;
        loop {
            tokio::select! {
                _ = &mut rx => {
                    batch_ok = true;
                    break;
                }
                res = join_set.join_next() => {
                    let Some(result) = res else {
                        return Ok(());
                    };
                    let is_failure = matches!(&result, Ok(Ok(PartEvent::Failed { .. })));
                    self.handle_join_result_item(result, pending, active, join_set).await?;
                    if is_failure {
                        break;
                    }
                }
                _ = self.ctx.cancel.cancelled() => {
                    return Ok(());
                }
            }
        }

        if !batch_ok {
            self.ramp_armed
                .store(false, std::sync::atomic::Ordering::Relaxed);
            return Ok(());
        }
        if pending.is_empty() || active.len() >= self.ctx.live.max_connections() {
            return Ok(());
        }

        let delay = sample_rampup_delay(self.rampup.delay_min, self.rampup.delay_max);
        if delay.is_zero() {
            continue;
        }
        tokio::select! {
            _ = tokio::time::sleep(delay) => {}
            _ = self.ctx.cancel.cancelled() => return Ok(()),
        }
    }
}
/// Cancels as many active tasks as exceed the live connection cap.
///
/// A cap of 0 is ignored. Victims are the first entries in the map's iteration
/// order; they are only signalled here and stay in `active` until their result
/// is handled.
fn apply_live_cap(&self, active: &mut HashMap<String, ActiveTask>) {
    let cap = self.ctx.live.max_connections();
    if cap == 0 {
        return;
    }
    let excess = active.len().saturating_sub(cap);
    for task in active.values().take(excess) {
        task.cancel.cancel();
    }
}
/// Applies the outcome of one finished part task to the scheduler state.
///
/// * `Completed` — the part is removed from the active sets and marked finished in the metadata.
/// * `NeedsReschedule` — the part goes back to the end of the pending queue.
/// * `Failed` — if nothing else is pending or active the download fails;
///   otherwise a part that was still tracked as active is re-queued and the
///   live connection cap shrinks by one.
///
/// # Errors
/// Returns the task's own error, a wrapped join error for a panicked task, or
/// the "all parts failed" error. The join set is shut down before any of these
/// are returned.
async fn handle_join_result_item(
    &self,
    res: Result<Result<PartEvent, OdlError>, tokio::task::JoinError>,
    pending: &mut VecDeque<PartDetails>,
    active: &mut HashMap<String, ActiveTask>,
    join_set: &mut JoinSet<Result<PartEvent, OdlError>>,
) -> Result<(), OdlError> {
    let event = match res {
        Ok(Ok(event)) => event,
        Ok(Err(task_err)) => {
            join_set.shutdown().await;
            return Err(task_err);
        }
        Err(join_err) => {
            join_set.shutdown().await;
            return Err(OdlError::Other {
                message: "Download task panicked".to_string(),
                origin: Box::new(join_err),
            });
        }
    };

    match event {
        PartEvent::Completed(outcome) => {
            active.remove(&outcome.ulid);
            self.active_parts.lock().unwrap().remove(&outcome.ulid);
            self.mark_part_finished(&outcome).await?;
        }
        PartEvent::NeedsReschedule { ulid } => {
            if let Some(task) = active.remove(&ulid) {
                self.active_parts.lock().unwrap().remove(&ulid);
                pending.push_back(task.details);
            }
        }
        PartEvent::Failed { ulid, attempts } => {
            self.active_parts.lock().unwrap().remove(&ulid);
            let failed_task = active.remove(&ulid);
            // Nothing left to fall back on: the whole download has failed.
            if pending.is_empty() && active.is_empty() {
                join_set.shutdown().await;
                return Err(OdlError::Other {
                    message: format!(
                        "All parts failed; last part {} failed after {} attempts",
                        ulid, attempts
                    ),
                    origin: Box::new(std::io::Error::other("all parts failed")),
                });
            }
            if let Some(task) = failed_task {
                pending.push_back(task.details);
                self.ctx.live.shrink_by_one();
            }
        }
    }
    Ok(())
}
/// Splits active parts until the pending queue can fill the spare connection slots.
///
/// Does nothing when dynamic splitting is disabled. Stops as soon as a split
/// attempt yields no new part.
async fn ensure_pending_pool(
    &self,
    pending: &mut VecDeque<PartDetails>,
    active: &mut HashMap<String, ActiveTask>,
) -> Result<(), OdlError> {
    let free_slots = self.ctx.live.max_connections().saturating_sub(active.len());
    let mut keep_splitting = self.dynamic_split;
    while keep_splitting && pending.len() < free_slots {
        keep_splitting = self.try_split_active(active, pending).await?;
    }
    Ok(())
}
/// Spawns the download task for `part` and registers it as active.
///
/// The part's controller starts from the size of any existing part file. A
/// `PartAdded` event is emitted before the task is spawned. `probe_notify`, if
/// given, is handed to the task.
///
/// # Errors
/// Fails when the existing part file cannot be inspected.
async fn schedule_part(
    &self,
    part: PartDetails,
    active: &mut HashMap<String, ActiveTask>,
    join_set: &mut JoinSet<Result<PartEvent, OdlError>>,
    probe_notify: Option<Arc<Notify>>,
) -> Result<(), OdlError> {
    let resumed_bytes = self.detect_existing_size(&part).await?;
    self.ctx.emit(ProgressEvent::PartAdded {
        ulid: part.ulid.clone(),
        offset: part.offset,
        size: part.size,
    });

    let controller = Arc::new(PartController::new(part.size, resumed_bytes));
    let cancel = CancellationToken::new();
    let span = info_span!("part", ulid = part.ulid.as_str());

    // The worker owns its own handles; the scheduler keeps the originals.
    let worker = download_part(
        Arc::clone(&self.client),
        Arc::clone(&self.instruction),
        part.clone(),
        Arc::clone(&controller),
        self.randomize_user_agent,
        self.speed_limiter.clone(),
        probe_notify,
        self.retry_policy,
        self.ctx.clone(),
        Arc::clone(&self.tracker),
        cancel.clone(),
    );
    join_set.spawn(worker.instrument(span));

    self.active_parts
        .lock()
        .unwrap()
        .insert(part.ulid.clone(), Arc::clone(&controller));
    let key = part.ulid.clone();
    active.insert(
        key,
        ActiveTask {
            details: part,
            controller,
            cancel,
        },
    );
    Ok(())
}
/// Returns the current length of the part's file, or 0 when the file does not exist.
///
/// # Errors
/// Any other I/O failure is returned as `OdlError::StdIoError` naming the path.
async fn detect_existing_size(&self, part: &PartDetails) -> Result<u64, OdlError> {
    let part_path = self.instruction.part_path(&part.ulid);
    match fs::metadata(&part_path).await {
        Ok(meta) => Ok(meta.len()),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(0),
        Err(e) => Err(OdlError::StdIoError {
            e,
            extra_info: Some(format!(
                "Failed to inspect download part at {}",
                part_path.display(),
            )),
        }),
    }
}
/// Tries to split the active part with the most bytes remaining.
///
/// Only parts of known size with at least `2 * MIN_DYNAMIC_SPLIT_SIZE` bytes
/// remaining are considered. On success the active entry's size is lowered to
/// the new limit, the new part is queued, and `true` is returned.
async fn try_split_active(
    &self,
    active: &mut HashMap<String, ActiveTask>,
    pending: &mut VecDeque<PartDetails>,
) -> Result<bool, OdlError> {
    let largest = active
        .iter()
        .filter(|(_, task)| {
            task.details.size != crate::download::Download::UNKNOWN_PART_SIZE
                && task.remaining_bytes() >= MIN_DYNAMIC_SPLIT_SIZE * 2
        })
        .max_by_key(|(_, task)| task.remaining_bytes());
    let Some((ulid, task)) = largest else {
        return Ok(false);
    };
    let candidate = SplitCandidate {
        ulid: ulid.clone(),
        controller: Arc::clone(&task.controller),
    };

    match self.split_task(&candidate).await? {
        Some((new_part, new_limit)) => {
            if let Some(entry) = active.get_mut(&candidate.ulid) {
                entry.details.size = new_limit;
            }
            pending.push_back(new_part);
            Ok(true)
        }
        None => Ok(false),
    }
}
/// Splits the candidate part in two and persists the updated metadata.
///
/// Returns `Ok(None)` when the download has not run long enough, the ETA is
/// too short, or `Download::compute_split` declines. Otherwise returns the
/// newly created right-hand part together with the candidate's new size.
///
/// # Errors
/// Fails when the candidate is missing from the metadata or the metadata
/// cannot be persisted.
async fn split_task(
&self,
candidate: &SplitCandidate,
) -> Result<Option<(PartDetails, u64)>, OdlError> {
// Skip splitting early in the download or when it is about to finish anyway.
if self.tracker.elapsed() <= MIN_DYNAMIC_SPLIT_ELAPSED
|| self.tracker.eta() <= MIN_DYNAMIC_SPLIT_ETA
{
return Ok(None);
}
let downloaded = candidate.controller.downloaded();
let current_limit = candidate.controller.limit();
// NOTE(review): the first argument (0) is presumably a part-relative start offset — confirm against `compute_split`.
let split =
match Download::compute_split(0, current_limit, downloaded, MIN_DYNAMIC_SPLIT_SIZE) {
Some(s) => s,
None => return Ok(None),
};
// Lower the running task's limit first; the metadata is updated afterwards.
candidate.controller.set_limit(split.new_left_size);
let (new_part, encoded_metadata) = {
let mut metadata = self.metadata.lock().await;
let part_entry = metadata.parts.get_mut(&candidate.ulid).ok_or_else(|| {
OdlError::MetadataError(MetadataError::Other {
message: format!("Part with ulid {} not found", candidate.ulid),
})
})?;
// The new part starts right where the shortened left part ends.
let new_part_offset = part_entry.offset + split.new_left_size;
part_entry.size = split.new_left_size;
let new_ulid = Ulid::new().to_string();
let new_part = PartDetails {
offset: new_part_offset,
size: split.new_right_size,
ulid: new_ulid.clone(),
finished: false,
};
metadata.parts.insert(new_ulid, new_part.clone());
// Encode while holding the lock; the file write happens after it is released.
let encoded = metadata.encode_length_delimited_to_vec();
(new_part, encoded)
};
self.persist_metadata_bytes(encoded_metadata).await?;
self.ctx.emit(ProgressEvent::PartAdded {
ulid: new_part.ulid.clone(),
offset: new_part.offset,
size: new_part.size,
});
Ok(Some((new_part, split.new_left_size)))
}
/// Marks the part as finished with its final size and persists the metadata.
///
/// A part that is not present in the metadata is silently ignored.
async fn mark_part_finished(&self, outcome: &PartOutcome) -> Result<(), OdlError> {
    let encoded = {
        let mut metadata = self.metadata.lock().await;
        match metadata.parts.get_mut(&outcome.ulid) {
            Some(part) => {
                part.finished = true;
                part.size = outcome.final_size;
            }
            None => return Ok(()),
        }
        // Encode under the lock; the file write happens after it is released.
        metadata.encode_length_delimited_to_vec()
    };
    self.persist_metadata_bytes(encoded).await
}
/// Writes already-encoded metadata to disk, one writer at a time.
///
/// # Errors
/// I/O failures are returned as `OdlError::StdIoError` naming the metadata path.
async fn persist_metadata_bytes(&self, encoded: Vec<u8>) -> Result<(), OdlError> {
    let _serialized = self.persist_mutex.lock().await;
    match persist_encoded_metadata(&encoded, &self.instruction).await {
        Ok(done) => Ok(done),
        Err(e) => Err(OdlError::StdIoError {
            e,
            extra_info: Some(format!(
                "Failed to persist metadata at {}",
                self.instruction.metadata_path().display()
            )),
        }),
    }
}
}
/// Scheduler-side record of a part whose download task is running.
struct ActiveTask {
// Part description; pushed back to the pending queue when the task is rescheduled.
details: PartDetails,
// Byte counters shared with the running task.
controller: Arc<PartController>,
// Per-task cancel token, triggered when the live connection cap shrinks.
cancel: CancellationToken,
}
/// Identifies the active part selected for a dynamic split.
struct SplitCandidate {
// Key of the part in the metadata and active maps.
ulid: String,
// Controller whose limit is lowered by the split.
controller: Arc<PartController>,
}
impl ActiveTask {
    /// Bytes still to download for this part; 0 once progress has reached the limit.
    fn remaining_bytes(&self) -> u64 {
        let limit = self.controller.limit();
        let done = self.controller.downloaded();
        limit.saturating_sub(done)
    }
}
/// Atomic byte counters shared between the scheduler and one part task.
struct PartController {
    /// Bytes recorded for the part so far.
    downloaded: AtomicU64,
    /// Number of bytes the part is allowed to reach.
    limit: AtomicU64,
}

impl PartController {
    /// Creates a controller with the given byte limit and starting progress.
    fn new(limit: u64, initial_downloaded: u64) -> Self {
        let downloaded = AtomicU64::new(initial_downloaded);
        let limit = AtomicU64::new(limit);
        Self { downloaded, limit }
    }

    /// Adds `delta` bytes to the progress counter and returns the new total.
    fn record_progress(&self, delta: u64) -> u64 {
        let before = self.downloaded.fetch_add(delta, Ordering::SeqCst);
        before + delta
    }

    /// Current progress in bytes.
    fn downloaded(&self) -> u64 {
        self.downloaded.load(Ordering::SeqCst)
    }

    /// Current byte limit.
    fn limit(&self) -> u64 {
        self.limit.load(Ordering::SeqCst)
    }

    /// Replaces the byte limit.
    fn set_limit(&self, new_limit: u64) {
        self.limit.store(new_limit, Ordering::SeqCst);
    }
}
/// Result data of a part that finished downloading.
struct PartOutcome {
// Identifier of the finished part.
ulid: String,
// Size written to the metadata for the part when it is marked finished.
final_size: u64,
}
/// What a part task reports back to the scheduler when it ends without a hard error.
enum PartEvent {
// The part reached its limit, or hit end of stream for a part of unknown size.
Completed(PartOutcome),
// The task stopped early and the part should be queued again.
NeedsReschedule { ulid: String },
// The retry policy gave up after `attempts` attempts.
Failed { ulid: String, attempts: u32 },
}
/// Token-bucket style limiter shared by all part tasks; waiters are served in queue order.
struct BandwidthLimiter {
// Refill rate in bytes per second; also the maximum number of stored tokens.
rate: f64,
// Bucket level and waiter queue, guarded by a std mutex.
state: std::sync::Mutex<LimiterState>,
// Source of unique sequence numbers for waiters.
seq: AtomicU64,
}
/// Mutable limiter state kept behind the limiter's mutex.
struct LimiterState {
// Tokens (bytes) currently available.
available: f64,
// Time of the last refill.
last_refill: Instant,
// Sequence numbers of waiting acquisitions, oldest first.
queue: VecDeque<u64>,
}
/// Removes a waiter's queue entry if its acquisition is dropped before it was served.
struct QueueGuard<'a> {
limiter: &'a BandwidthLimiter,
// Sequence number this guard is responsible for.
seq: u64,
// Set once the entry was popped by a successful acquisition.
consumed: bool,
}
impl Drop for QueueGuard<'_> {
fn drop(&mut self) {
if !self.consumed
&& let Ok(mut state) = self.limiter.state.lock()
{
state.queue.retain(|&s| s != self.seq);
}
}
}
impl BandwidthLimiter {
/// Creates a limiter for `bytes_per_second` (at least 1) that starts with a full bucket.
fn new(bytes_per_second: u64) -> Self {
let rate = bytes_per_second.max(1) as f64;
Self {
rate,
state: std::sync::Mutex::new(LimiterState {
available: rate,
last_refill: Instant::now(),
queue: VecDeque::new(),
}),
seq: AtomicU64::new(1),
}
}
/// Waits until `amount` bytes worth of tokens have been taken from the bucket.
///
/// Requests larger than the bucket capacity are served in capacity-sized pieces.
async fn acquire(&self, amount: u64) {
// The bucket never holds more than `rate` tokens, so larger requests must be split.
let chunk_cap = self.rate as u64;
let mut remaining = amount;
while remaining > 0 {
let take = remaining.min(chunk_cap);
self.acquire_one(take).await;
remaining -= take;
}
}
/// Queues one acquisition of `amount` tokens and waits until it is served.
///
/// A waiter is only served while it is at the front of the queue. Dropping
/// the future removes its queue entry through `QueueGuard`.
async fn acquire_one(&self, amount: u64) {
let amount_f = amount as f64;
let my_seq = self.seq.fetch_add(1, Ordering::SeqCst);
{
let mut state = self.state.lock().expect("limiter mutex poisoned");
state.queue.push_back(my_seq);
}
let mut guard = QueueGuard {
limiter: self,
seq: my_seq,
consumed: false,
};
loop {
// The mutex guard lives only inside this block, so it is never held across an await.
let sleep_duration = {
let mut state = self.state.lock().expect("limiter mutex poisoned");
state.refill(self.rate);
if let Some(&front) = state.queue.front()
&& front == my_seq
&& state.available >= amount_f
{
state.available -= amount_f;
state.queue.pop_front();
guard.consumed = true;
return;
}
if state.available < amount_f {
// Sleep for the time the refill needs to cover the deficit, at least 1 ms.
let deficit = amount_f - state.available;
let wait_secs = deficit / self.rate;
match Duration::try_from_secs_f64(wait_secs) {
Ok(d) => Some(d.max(Duration::from_millis(1))),
Err(_) => Some(Duration::from_millis(1)),
}
} else {
None
}
};
if let Some(dur) = sleep_duration {
time::sleep(dur).await;
} else {
// Enough tokens exist but an earlier waiter is still queued ahead of us.
// NOTE(review): this re-polls via `yield_now` until the head waiter is served,
// which looks like a busy-wait — confirm this is acceptable.
tokio::task::yield_now().await;
}
}
}
}
impl LimiterState {
    /// Adds the tokens earned since the last refill, never exceeding `rate` tokens in total.
    fn refill(&mut self, rate: f64) {
        let now = Instant::now();
        let since_last = now - self.last_refill;
        self.last_refill = now;
        let topped_up = self.available + since_last.as_secs_f64() * rate;
        self.available = f64::min(topped_up, rate);
    }
}
// Capacity of the `BufWriter` that wraps each part file (larger on Windows).
#[cfg(windows)]
const PART_WRITER_BUF_SIZE: usize = 1024 * 1024;
#[cfg(not(windows))]
const PART_WRITER_BUF_SIZE: usize = 256 * 1024;
// Number of chunks that may be queued for the writer task before `write` has to wait.
const PART_WRITER_CHANNEL_CAP: usize = 64;
/// Appends chunks to a part file through a blocking writer task fed by a channel.
struct PartFileWriter {
// Sending half of the chunk channel; `None` once the writer was finished.
tx: Option<mpsc::Sender<Bytes>>,
// Handle of the blocking writer task; taken by `finish`.
handle: Option<tokio::task::JoinHandle<std::io::Result<()>>>,
}
impl PartFileWriter {
async fn open(part_path: std::path::PathBuf) -> std::io::Result<Self> {
let file = tokio::task::spawn_blocking(move || -> std::io::Result<std::fs::File> {
use std::io::Seek;
let mut f = std::fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(false)
.open(&part_path)?;
f.seek(std::io::SeekFrom::End(0))?;
Ok(f)
})
.await
.map_err(|e| std::io::Error::other(e.to_string()))??;
let (tx, mut rx) = mpsc::channel::<Bytes>(PART_WRITER_CHANNEL_CAP);
let handle = tokio::task::spawn_blocking(move || -> std::io::Result<()> {
use std::io::Write;
let mut writer = std::io::BufWriter::with_capacity(PART_WRITER_BUF_SIZE, file);
while let Some(chunk) = rx.blocking_recv() {
writer.write_all(&chunk)?;
}
writer.flush()?;
Ok(())
});
Ok(Self {
tx: Some(tx),
handle: Some(handle),
})
}
async fn write(&mut self, chunk: Bytes) -> std::io::Result<()> {
let tx = self
.tx
.as_ref()
.expect("PartFileWriter::write after finish");
if tx.send(chunk).await.is_err() {
return self.finish().await;
}
Ok(())
}
async fn finish(&mut self) -> std::io::Result<()> {
self.tx.take();
if let Some(h) = self.handle.take() {
match h.await {
Ok(r) => r,
Err(e) => Err(std::io::Error::other(e.to_string())),
}
} else {
Ok(())
}
}
}
impl Drop for PartFileWriter {
    /// Closes the channel and detaches the writer task without waiting for it.
    fn drop(&mut self) {
        drop(self.tx.take());
        drop(self.handle.take());
    }
}
/// Downloads one part into its part file, retrying according to `policy`.
///
/// Each attempt appends to the part file. For parts of known size the request
/// carries a `Range` header that starts after the bytes recorded so far and
/// ends at the controller's current limit. Chunks are cut at the limit, so a
/// limit lowered by a dynamic split is honoured. The first received chunk of
/// an attempt fires `probe_notify`, if one was given.
///
/// Returns
/// * `Completed` once the limit is reached, or at end of stream for a part of unknown size;
/// * `NeedsReschedule` when `task_cancel` fires or the stream ends before the limit;
/// * `Failed` when the retry policy gives up.
///
/// A response with a non-success HTTP status is treated like a failed request
/// and retried; its body is never written to the part file. A mid-stream error
/// or stall costs exactly one attempt and one back-off.
///
/// # Errors
/// `OdlError::Cancelled` when the download context is cancelled, and I/O
/// errors from opening, writing, or flushing the part file.
// NOTE(review): for a part of unknown size a retry sends no Range header while
// the file is opened in append mode, so a retried body would presumably be
// appended after the partial data — confirm how callers handle this.
// NOTE(review): a 200 answer to a ranged request passes the status check; it
// may be worth requiring 206 when the requested range does not start at 0.
#[allow(clippy::too_many_arguments)]
async fn download_part(
    client: Arc<Client>,
    instruction: Arc<Download>,
    part: PartDetails,
    controller: Arc<PartController>,
    randomize_user_agent: bool,
    speed_limiter: Option<Arc<BandwidthLimiter>>,
    probe_notify: Option<Arc<Notify>>,
    policy: FixedThenExponentialRetry,
    ctx: DownloadContext,
    tracker: Arc<ProgressTracker>,
    task_cancel: CancellationToken,
) -> Result<PartEvent, OdlError> {
    if ctx.is_cancelled() {
        return Err(OdlError::Cancelled);
    }
    if task_cancel.is_cancelled() {
        return Ok(PartEvent::NeedsReschedule { ulid: part.ulid });
    }
    let PartDetails {
        offset, size, ulid, ..
    } = part;
    let part_path = instruction.part_path(&ulid);
    let url = instruction.url().clone();
    let unknown_size = size == crate::download::Download::UNKNOWN_PART_SIZE;
    let mut attempts: u32 = 0;
    loop {
        let current_size = controller.downloaded();
        let mut file = match PartFileWriter::open(part_path.clone()).await {
            Ok(w) => w,
            Err(e) => {
                return Err(OdlError::StdIoError {
                    e,
                    extra_info: Some(format!("Failed to open part file {}", part_path.display())),
                });
            }
        };
        // Read the limit on every attempt: a dynamic split may have lowered it.
        let target_size = controller.limit();
        if !unknown_size && current_size >= target_size {
            file.finish().await?;
            ctx.emit(ProgressEvent::PartFinished { ulid: ulid.clone() });
            return Ok(PartEvent::Completed(PartOutcome {
                ulid,
                final_size: target_size,
            }));
        }

        let mut req = client.get(url.clone());
        if !unknown_size {
            // `target_size > current_size` holds here, so the range is never empty.
            let range_header = format!(
                "bytes={}-{}",
                offset + current_size,
                offset + target_size - 1,
            );
            let range_value = match HeaderValue::from_str(&range_header) {
                Ok(v) => v,
                Err(e) => {
                    let _ = file.finish().await;
                    return Err(OdlError::Other {
                        message: "Internal Error: Invalid range header".to_string(),
                        origin: Box::new(e),
                    });
                }
            };
            req = req.header(RANGE, range_value);
            req = req.header(ACCEPT_ENCODING, HeaderValue::from_static("identity"));
        }
        if randomize_user_agent {
            req = req.header(USER_AGENT, random_user_agent())
        }

        let send_result = time::timeout(STALE_CONNECTION_TIMEOUT, req.send()).await;
        let mut resp = match send_result {
            Ok(Ok(r)) if r.status().is_success() => r,
            // Transport error, timeout, or an error status: retry without touching the file.
            _ => {
                file.finish().await?;
                let next_attempt = attempts.saturating_add(1);
                match retry_sleep_or_fail_part(&policy, attempts, next_attempt, &ctx, &ulid).await
                {
                    Ok(()) => {
                        attempts = next_attempt;
                        continue;
                    }
                    Err(failed) => return Ok(failed),
                }
            }
        };

        let mut started_notified = false;
        let mut saw_eof = false;
        loop {
            if !unknown_size && controller.downloaded() >= controller.limit() {
                break;
            }
            let chunk_result = tokio::select! {
                biased;
                _ = ctx.cancel.cancelled() => {
                    let _ = file.finish().await;
                    return Err(OdlError::Cancelled);
                }
                _ = task_cancel.cancelled() => {
                    let _ = file.finish().await;
                    return Ok(PartEvent::NeedsReschedule { ulid });
                }
                r = time::timeout(STALE_CONNECTION_TIMEOUT, resp.chunk()) => r,
            };
            let mut chunk = match chunk_result {
                Ok(Ok(Some(chunk))) => chunk,
                Ok(Ok(None)) => {
                    saw_eof = true;
                    break;
                }
                // Stream error or stall: leave the loop; the tail below does the
                // single retry accounting for this attempt.
                Ok(Err(_)) | Err(_) => break,
            };
            if !started_notified {
                if let Some(n) = probe_notify.as_ref() {
                    n.notify_one();
                }
                started_notified = true;
            }
            if !unknown_size {
                // Re-read the limit: it may have been lowered while waiting for the chunk.
                let remaining = controller
                    .limit()
                    .saturating_sub(controller.downloaded());
                if chunk.len() as u64 > remaining {
                    chunk = chunk.split_to(remaining as usize);
                }
            }
            let len = chunk.len() as u64;
            if let Some(limiter) = speed_limiter.as_ref() {
                tokio::select! {
                    _ = limiter.acquire(len) => {}
                    _ = ctx.cancel.cancelled() => {
                        let _ = file.finish().await;
                        return Err(OdlError::Cancelled);
                    }
                    _ = task_cancel.cancelled() => {
                        let _ = file.finish().await;
                        return Ok(PartEvent::NeedsReschedule { ulid });
                    }
                }
            }
            file.write(chunk).await?;
            controller.record_progress(len);
            tracker.advance(len);
            if ctx.is_cancelled() {
                let _ = file.finish().await;
                return Err(OdlError::Cancelled);
            }
        }

        // Flush everything queued so far before deciding how this attempt ended.
        file.finish().await?;
        if unknown_size && saw_eof {
            let final_size = controller.downloaded();
            ctx.emit(ProgressEvent::PartFinished { ulid: ulid.clone() });
            return Ok(PartEvent::Completed(PartOutcome { ulid, final_size }));
        }
        if controller.downloaded() >= controller.limit() {
            ctx.emit(ProgressEvent::PartFinished { ulid: ulid.clone() });
            return Ok(PartEvent::Completed(PartOutcome {
                ulid,
                final_size: controller.limit(),
            }));
        }
        if saw_eof {
            return Ok(PartEvent::NeedsReschedule { ulid });
        }
        // Only reached after a stream error or stall: count one attempt and back off once.
        attempts = attempts.saturating_add(1);
        match retry_sleep_or_fail_part(&policy, attempts, attempts, &ctx, &ulid).await {
            Ok(()) => continue,
            Err(failed) => return Ok(failed),
        }
    }
}
/// Announces a retry for the part and waits for the policy's back-off.
///
/// Emits `PartRetrying` with `attempts_display`, then calls `wait_for_retry`
/// with that same number. Returns `Err(PartEvent::Failed)` when it reports
/// `false`. `_attempts_for_policy` is currently unused.
async fn retry_sleep_or_fail_part(
    policy: &FixedThenExponentialRetry,
    _attempts_for_policy: u32,
    attempts_display: u32,
    ctx: &DownloadContext,
    ulid: &str,
) -> Result<(), PartEvent> {
    let part_id = ulid.to_string();
    ctx.emit(ProgressEvent::PartRetrying {
        ulid: part_id.clone(),
        attempt: attempts_display,
    });
    let may_retry = wait_for_retry(policy, attempts_display, ctx).await;
    if !may_retry {
        return Err(PartEvent::Failed {
            ulid: part_id,
            attempts: attempts_display,
        });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::download::DownloadBuilder;
use futures::FutureExt;
use mockito::{Matcher, Server};
use reqwest::Url;
use tempfile::tempdir;
use tokio::{fs, time};
const TEST_FILENAME: &str = "test.bin";
/// Builds an unfinished part with the given identifier, offset, and size.
fn make_part(ulid: &str, offset: u64, size: u64) -> PartDetails {
    let ulid = ulid.to_owned();
    PartDetails {
        ulid,
        offset,
        size,
        finished: false,
    }
}
/// Builds a resumable `Download` for tests with a known size and the given parts.
///
/// Panics when the URL is invalid or the builder rejects the configuration.
async fn create_instruction(
    download_dir: &std::path::Path,
    save_dir: &std::path::Path,
    url: &str,
    size: u64,
    parts: HashMap<String, PartDetails>,
    max_connections: u64,
) -> Arc<Download> {
    let parsed_url = Url::parse(url).expect("valid url");
    let built = DownloadBuilder::default()
        .download_dir(download_dir.to_path_buf())
        .save_dir(save_dir.to_path_buf())
        .filename(TEST_FILENAME.to_string())
        .url(parsed_url)
        .size(Some(size))
        .parts(parts)
        .max_connections(max_connections)
        .is_resumable(true)
        .build();
    Arc::new(built.expect("build download"))
}
/// Reads and decodes the length-delimited metadata file of `instruction`.
///
/// Panics when the file is missing or cannot be decoded.
async fn read_metadata(instruction: &Download) -> DownloadMetadata {
    let raw = fs::read(instruction.metadata_path())
        .await
        .expect("metadata file present");
    DownloadMetadata::decode_length_delimited(raw.as_slice()).expect("decode metadata")
}
// A single part of known size is fetched with one ranged request, written to its
// part file, and marked finished in the returned metadata.
#[tokio::test]
async fn test_downloader_downloads_single_part() -> Result<(), Box<dyn std::error::Error>> {
let file_content = b"HelloDownloader";
let mut server = Server::new_async().await;
let base = server.url();
// The mock only matches a request whose Range header covers the whole file.
let get_mock = server
.mock("GET", "/file")
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(file_content)
.create_async()
.await;
let tmp = tempdir()?;
let download_dir = tmp.path().join("download");
let save_dir = tmp.path().join("save");
fs::create_dir_all(&download_dir).await?;
fs::create_dir_all(&save_dir).await?;
let mut parts = HashMap::new();
parts.insert(
"part1".to_string(),
make_part("part1", 0, file_content.len() as u64),
);
let instruction = create_instruction(
&download_dir,
&save_dir,
&format!("{}/file", base),
file_content.len() as u64,
parts,
1,
)
.await;
let metadata = instruction.as_metadata();
let downloader = Downloader::new(
Arc::clone(&instruction),
metadata,
reqwest::Client::builder().build()?,
false,
None,
true,
RampupConfig::disabled(),
FixedThenExponentialRetry::default(),
DownloadContext::new(),
);
let updated_metadata = downloader.run().await?;
let part_bytes = fs::read(instruction.part_path("part1")).await?;
assert_eq!(part_bytes, file_content);
assert!(
updated_metadata
.parts
.get("part1")
.map(|p| p.finished)
.unwrap_or(false)
);
// Completing the part must also have persisted the metadata file.
assert!(fs::try_exists(instruction.metadata_path()).await?);
get_mock.assert_async().await;
Ok(())
}
// A part of unknown size is requested without a Range header and streamed until
// end of body; its final size is then recorded in the metadata.
#[tokio::test]
async fn test_downloader_streams_unknown_size_until_eof()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"<html><body>hello world</body></html>";
let mut server = Server::new_async().await;
let base = server.url();
// The mock only matches when no Range header is sent.
let get_mock = server
.mock("GET", "/page")
.match_header("range", Matcher::Missing)
.with_status(200)
.with_body(file_content)
.create_async()
.await;
let tmp = tempdir()?;
let download_dir = tmp.path().join("download");
let save_dir = tmp.path().join("save");
fs::create_dir_all(&download_dir).await?;
fs::create_dir_all(&save_dir).await?;
let mut parts = HashMap::new();
parts.insert(
"part1".to_string(),
make_part("part1", 0, crate::download::Download::UNKNOWN_PART_SIZE),
);
// Built by hand because this download has no known size and is not resumable.
let download = DownloadBuilder::default()
.download_dir(download_dir.clone())
.save_dir(save_dir.clone())
.filename(TEST_FILENAME.to_string())
.url(Url::parse(&format!("{}/page", base))?)
.size(None)
.parts(parts)
.max_connections(1)
.is_resumable(false)
.build()?;
let instruction = Arc::new(download);
let metadata = instruction.as_metadata();
let downloader = Downloader::new(
Arc::clone(&instruction),
metadata,
reqwest::Client::builder().build()?,
false,
None,
true,
RampupConfig::disabled(),
FixedThenExponentialRetry::default(),
DownloadContext::new(),
);
let updated_metadata = downloader.run().await?;
let part_bytes = fs::read(instruction.part_path("part1")).await?;
assert_eq!(part_bytes, file_content);
let part = updated_metadata
.parts
.get("part1")
.expect("part1 present after run");
assert!(part.finished);
// The size is replaced by the number of bytes actually received.
assert_eq!(part.size, file_content.len() as u64);
get_mock.assert_async().await;
Ok(())
}
// `apply_live_cap` cancels nothing while the cap covers all active tasks and
// cancels exactly the surplus once the cap drops below the active count.
#[tokio::test]
async fn test_apply_live_cap_cancels_surplus() -> Result<(), Box<dyn std::error::Error>> {
let tmp = tempdir()?;
let download_dir = tmp.path().join("download");
let save_dir = tmp.path().join("save");
fs::create_dir_all(&download_dir).await?;
fs::create_dir_all(&save_dir).await?;
let mut parts = HashMap::new();
parts.insert("p1".to_string(), make_part("p1", 0, 1024));
let instruction = create_instruction(
&download_dir,
&save_dir,
"http://example.com/file",
1024,
parts,
3,
)
.await;
let metadata = instruction.as_metadata();
let downloader = Downloader::new(
Arc::clone(&instruction),
metadata,
reqwest::Client::builder().build()?,
false,
None,
true,
RampupConfig::disabled(),
FixedThenExponentialRetry::default(),
DownloadContext::new(),
);
// Hand-built active entries; no download task is spawned for them.
let make_task = |size: u64| ActiveTask {
details: make_part("x", 0, size),
controller: Arc::new(PartController::new(size, 0)),
cancel: CancellationToken::new(),
};
let mut active: HashMap<String, ActiveTask> = HashMap::new();
active.insert("a".to_string(), make_task(1024));
active.insert("b".to_string(), make_task(1024));
active.insert("c".to_string(), make_task(1024));
// Cap equals the number of active tasks: nothing is cancelled.
downloader.ctx.live.set_max_connections(3);
downloader.apply_live_cap(&mut active);
assert_eq!(
active.values().filter(|t| t.cancel.is_cancelled()).count(),
0
);
// Cap of 1 with three active tasks: two are cancelled.
downloader.ctx.live.set_max_connections(1);
downloader.apply_live_cap(&mut active);
assert_eq!(
active.values().filter(|t| t.cancel.is_cancelled()).count(),
2
);
Ok(())
}
/// Exercises the live connection-cap accessors of a fresh `DownloadContext`.
#[tokio::test]
async fn test_live_controls_seed_and_set() {
    let context = DownloadContext::new();
    let live = &context.live;

    // A new context reports 0 until something seeds or sets the cap.
    assert_eq!(live.max_connections(), 0);

    // The first seed takes effect; a second seed leaves the value untouched.
    live.seed_if_unset(4);
    assert_eq!(live.max_connections(), 4);
    live.seed_if_unset(8);
    assert_eq!(live.max_connections(), 4);

    // Setting 0 reads back as 1; any other value reads back unchanged.
    live.set_max_connections(0);
    assert_eq!(live.max_connections(), 1);
    live.set_max_connections(6);
    assert_eq!(live.max_connections(), 6);
}
#[tokio::test]
async fn test_downloader_split_persists_metadata() -> Result<(), Box<dyn std::error::Error>> {
let tmp = tempdir()?;
let download_dir = tmp.path().join("download");
let save_dir = tmp.path().join("save");
fs::create_dir_all(&download_dir).await?;
fs::create_dir_all(&save_dir).await?;
let mut parts = HashMap::new();
let original_size = MIN_DYNAMIC_SPLIT_SIZE * 4;
parts.insert("orig".to_string(), make_part("orig", 0, original_size));
let instruction = create_instruction(
&download_dir,
&save_dir,
"http://example.com/file",
original_size,
parts,
2,
)
.await;
let metadata = instruction.as_metadata();
let downloader = Downloader::new(
Arc::clone(&instruction),
metadata,
reqwest::Client::builder().build()?,
false,
None,
true,
RampupConfig::disabled(),
FixedThenExponentialRetry::default(),
DownloadContext::new(),
);
downloader.tracker.set_total(Some(120_000));
downloader.tracker.advance(1);
time::sleep(Duration::from_millis(100)).await;
assert!(downloader.tracker.eta() > MIN_DYNAMIC_SPLIT_ETA);
let controller = Arc::new(PartController::new(original_size, 0));
let candidate = SplitCandidate {
ulid: "orig".to_string(),
controller: Arc::clone(&controller),
};
let split_result = downloader.split_task(&candidate).await?;
assert!(split_result.is_some());
let persisted = read_metadata(&instruction).await;
assert_eq!(persisted.parts.len(), 2);
assert!(persisted.parts.values().any(|p| p.ulid != "orig"));
Ok(())
}
#[tokio::test]
async fn test_downloader_mark_part_finished_persists() -> Result<(), Box<dyn std::error::Error>>
{
let tmp = tempdir()?;
let download_dir = tmp.path().join("download");
let save_dir = tmp.path().join("save");
fs::create_dir_all(&download_dir).await?;
fs::create_dir_all(&save_dir).await?;
let mut parts = HashMap::new();
parts.insert("p1".to_string(), make_part("p1", 0, 1024));
let instruction = create_instruction(
&download_dir,
&save_dir,
"http://example.com/file",
1024,
parts,
1,
)
.await;
let metadata = instruction.as_metadata();
let downloader = Downloader::new(
Arc::clone(&instruction),
metadata,
reqwest::Client::builder().build()?,
false,
None,
true,
RampupConfig::disabled(),
FixedThenExponentialRetry::default(),
DownloadContext::new(),
);
let outcome = PartOutcome {
ulid: "p1".to_string(),
final_size: 1024,
};
downloader.mark_part_finished(&outcome).await?;
let persisted = read_metadata(&instruction).await;
let part = persisted.parts.get("p1").expect("part exists");
assert!(part.finished);
assert_eq!(part.size, 1024);
Ok(())
}
/// Checks that `download_part` yields `PartEvent::NeedsReschedule` when the server
/// sends fewer bytes than the requested range.
///
/// The part covers five bytes (`bytes=0-4`) but the mock answers with a two-byte body.
#[tokio::test]
async fn test_download_part_returns_reschedule_on_short_body()
-> Result<(), Box<dyn std::error::Error>> {
    let scratch = tempdir()?;
    let download_dir = scratch.path().join("download");
    let save_dir = scratch.path().join("save");
    fs::create_dir_all(&download_dir).await?;
    fs::create_dir_all(&save_dir).await?;

    // 206 response to the exact range request, but with a body that is too short.
    let mut server = Server::new_async().await;
    let partial_url = format!("{}/partial", server.url());
    let short_body = b"12";
    let range_mock = server
        .mock("GET", "/partial")
        .match_header("range", Matcher::Exact("bytes=0-4".into()))
        .with_status(206)
        .with_body(short_body)
        .create_async()
        .await;

    let parts = HashMap::from([("part".to_string(), make_part("part", 0, 5))]);
    let instruction =
        create_instruction(&download_dir, &save_dir, &partial_url, 5, parts, 1).await;

    let metadata = instruction.as_metadata();
    let part = metadata.parts.get("part").cloned().unwrap();
    let controller = Arc::new(PartController::new(part.size, 0));

    let event = download_part(
        Arc::new(reqwest::Client::builder().build()?),
        Arc::clone(&instruction),
        part,
        controller,
        false,
        None,
        None,
        FixedThenExponentialRetry::default(),
        DownloadContext::new(),
        Arc::new(ProgressTracker::new(Some(5))),
        CancellationToken::new(),
    )
    .await?;

    // Only a reschedule request for this very part is acceptable.
    match event {
        PartEvent::Completed(_) => panic!("expected reschedule"),
        PartEvent::Failed { ulid, attempts } => panic!(
            "unexpected failed part {} after {} attempts",
            ulid, attempts
        ),
        PartEvent::NeedsReschedule { ulid } => assert_eq!(ulid, "part"),
    }

    range_mock.assert_async().await;
    Ok(())
}
/// Checks that a second full-budget `acquire` on a `BandwidthLimiter::new(1024)` only
/// completes once enough (paused, manually advanced) time has passed.
#[tokio::test(start_paused = true)]
async fn test_bandwidth_limiter_enforces_limit() {
    let limiter = BandwidthLimiter::new(1024);

    // The first request for 1024 goes through.
    limiter.acquire(1024).await;

    // A second request for 1024 must not be ready right away...
    let queued = limiter.acquire(1024);
    tokio::pin!(queued);
    assert!(queued.as_mut().now_or_never().is_none());

    // ...nor after 900 ms...
    time::advance(Duration::from_millis(900)).await;
    assert!(queued.as_mut().now_or_never().is_none());

    // ...but it is ready once a further 200 ms have elapsed (1100 ms in total).
    time::advance(Duration::from_millis(200)).await;
    assert!(queued.as_mut().now_or_never().is_some());
}
/// Checks that dropping a still-pending `acquire` future does not keep later callers
/// waiting: after the clock advances 1100 ms a fresh `acquire` completes immediately.
#[tokio::test(start_paused = true)]
async fn test_bandwidth_limiter_dropped_acquire_does_not_block_queue() {
    let limiter = BandwidthLimiter::new(1024);
    limiter.acquire(1024).await;

    // Poll a second acquire once so it is pending, then drop it at the end of the scope.
    {
        let abandoned = limiter.acquire(1024);
        tokio::pin!(abandoned);
        assert!(abandoned.as_mut().now_or_never().is_none());
    }

    time::advance(Duration::from_millis(1100)).await;

    // The abandoned request must not stand in the way of the next one.
    let next = limiter.acquire(1024);
    tokio::pin!(next);
    assert!(next.as_mut().now_or_never().is_some());
}
/// Checks that acquiring more than the configured rate in one call (32 KiB against a
/// limiter created with 8192) completes within ten seconds instead of hanging.
#[tokio::test]
async fn test_bandwidth_limiter_handles_amount_larger_than_rate() {
    let limiter = Arc::new(BandwidthLimiter::new(8192));
    let oversized = limiter.acquire(32 * 1024);
    tokio::time::timeout(Duration::from_secs(10), oversized)
        .await
        .expect("acquire must not deadlock for amount > rate");
}
/// `sample_rampup_delay` must return `min` whenever `max` does not exceed it.
#[test]
fn sample_rampup_delay_clamps_when_max_le_min() {
    let floor = Duration::from_millis(500);
    // First an inverted range (max < min), then a degenerate one (max == min).
    for ceiling in [Duration::from_millis(200), floor] {
        assert_eq!(sample_rampup_delay(floor, ceiling), floor);
    }
}
/// Draws 2000 samples from `sample_rampup_delay` and requires each one to lie in the
/// closed interval `[min, max]`.
#[test]
fn sample_rampup_delay_stays_within_bounds() {
    let min = Duration::from_millis(500);
    let max = Duration::from_millis(1000);
    let allowed = min..=max;
    (0..2000)
        .map(|_| sample_rampup_delay(min, max))
        .for_each(|d| {
            assert!(allowed.contains(&d), "delay {:?} out of bounds", d);
        });
}
/// Starts a local TCP server whose responses never finish.
///
/// For every accepted connection the server increments the shared counter, reads until
/// the end of the request head (`\r\n\r\n`), replies with a `206 Partial Content`
/// chunked response containing a single one-byte chunk, and then parks forever without
/// closing the socket.
///
/// Returns the bound address, the accepted-connection counter, and the handle of the
/// accept-loop task (the loop ends on its own only if `accept` fails).
async fn spawn_hanging_http_server() -> (
    std::net::SocketAddr,
    Arc<std::sync::atomic::AtomicUsize>,
    tokio::task::JoinHandle<()>,
) {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    // Port 0 lets the OS pick a free port; the actual address is reported to the caller.
    let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
    let addr = listener.local_addr().expect("local addr");
    let counter = Arc::new(AtomicUsize::new(0));
    let counter_clone = Arc::clone(&counter);

    let handle = tokio::spawn(async move {
        loop {
            let Ok((mut sock, _)) = listener.accept().await else {
                return;
            };
            // Count the connection as soon as it is accepted, before any I/O on it.
            counter_clone.fetch_add(1, Ordering::SeqCst);

            tokio::spawn(async move {
                // Consume the request head. Its contents (including any Range header)
                // are not interpreted, because every request gets the same response.
                let mut buf = [0u8; 4096];
                let mut acc = Vec::new();
                loop {
                    let n = match sock.read(&mut buf).await {
                        // Peer closed or errored before finishing the head: give up.
                        Ok(0) | Err(_) => return,
                        Ok(n) => n,
                    };
                    acc.extend_from_slice(&buf[..n]);
                    if acc.windows(4).any(|w| w == b"\r\n\r\n") {
                        break;
                    }
                }

                // Response head followed by one chunk holding a single 0x00 byte.
                // The terminating zero-length chunk is never sent.
                let header = "HTTP/1.1 206 Partial Content\r\nTransfer-Encoding: chunked\r\nAccept-Ranges: bytes\r\nConnection: keep-alive\r\n\r\n";
                let mut out = header.as_bytes().to_vec();
                out.extend_from_slice(b"1\r\n\x00\r\n");
                if sock.write_all(&out).await.is_err() {
                    return;
                }
                let _ = sock.flush().await;

                // Park forever. The socket is owned by this task, so it stays open
                // for as long as the task exists.
                std::future::pending::<()>().await;
                drop(sock);
            });
        }
    });

    (addr, counter, handle)
}
/// Polls `pred` every 5 ms until it returns `true`.
///
/// # Panics
///
/// Panics with a message containing `timeout` and `label` if the deadline
/// (`now + timeout`, measured on the tokio clock) is reached while `pred` is still
/// `false`.
async fn wait_for<F>(label: &str, timeout: Duration, mut pred: F)
where
    F: FnMut() -> bool,
{
    let deadline = tokio::time::Instant::now() + timeout;
    while !pred() {
        // The deadline is checked only after a failed poll, so `pred` is always
        // evaluated at least once, even with a zero timeout.
        if tokio::time::Instant::now() >= deadline {
            panic!("timed out after {:?}: {}", timeout, label);
        }
        tokio::time::sleep(Duration::from_millis(5)).await;
    }
}
async fn build_rampup_test_downloader(
addr: std::net::SocketAddr,
n_parts: u64,
rampup: RampupConfig,
) -> (Arc<Download>, Downloader, tempfile::TempDir) {
let tmp = tempdir().expect("tmp");
let download_dir = tmp.path().join("download");
let save_dir = tmp.path().join("save");
fs::create_dir_all(&download_dir).await.expect("mkdir dl");
fs::create_dir_all(&save_dir).await.expect("mkdir save");
let part_size: u64 = 1024 * 1024;
let total = part_size * n_parts;
let mut parts = HashMap::new();
for i in 0..n_parts {
let ulid = format!("p{i}");
parts.insert(ulid.clone(), make_part(&ulid, i * part_size, part_size));
}
let url = format!("http://{}/file", addr);
let instruction =
create_instruction(&download_dir, &save_dir, &url, total, parts, n_parts).await;
let metadata = instruction.as_metadata();
let downloader = Downloader::new(
Arc::clone(&instruction),
metadata,
reqwest::Client::builder().build().expect("client"),
false,
None,
false, rampup,
FixedThenExponentialRetry::default(),
DownloadContext::new(),
);
(instruction, downloader, tmp)
}
/// Checks that, with ramp-up enabled, connections are opened in timed batches.
///
/// The download has 7 parts and runs against the hanging test server with a batch
/// size of 2 and a fixed 300 ms inter-batch delay. The server's accepted-connection
/// counter is expected to step 3 -> 5 -> 7 and then stay at 7.
///
/// The test is timing-sensitive: it runs on the real clock and the order of waits,
/// sleeps and assertions below is significant.
#[tokio::test(flavor = "multi_thread")]
async fn fill_capacity_ramps_connections_in_batches() {
use std::sync::atomic::Ordering;
// delay_min == delay_max, so the sampled inter-batch delay is always exactly DELAY.
const DELAY: Duration = Duration::from_millis(300);
let (addr, counter, server_task) = spawn_hanging_http_server().await;
let (_instruction, downloader, _tmp) = build_rampup_test_downloader(
addr,
7,
RampupConfig {
enabled: true,
batch_size: 2,
delay_min: DELAY,
delay_max: DELAY,
},
)
.await;
// Keep a handle on the cancellation token before the downloader moves into its task.
let cancel = downloader.ctx.cancel.clone();
let dl_task = tokio::spawn(async move {
// The run result is irrelevant here; the run is ended by cancellation below.
let _ = downloader.run().await;
});
// Step 1: three connections (labelled "probe + first batch") must show up.
wait_for(
"counter >= 3 (probe + first batch)",
Duration::from_secs(5),
|| counter.load(Ordering::SeqCst) >= 3,
)
.await;
// A third of the delay later the count must still be exactly 3.
tokio::time::sleep(DELAY / 3).await;
assert_eq!(
counter.load(Ordering::SeqCst),
3,
"rampup must wait for inter-batch delay before opening more"
);
// Step 2: the second batch brings the count to 5, and again it must hold there.
wait_for("counter >= 5 (second batch)", DELAY * 3, || {
counter.load(Ordering::SeqCst) >= 5
})
.await;
tokio::time::sleep(DELAY / 3).await;
assert_eq!(counter.load(Ordering::SeqCst), 5);
// Step 3: the last batch reaches 7 connections. The predicate accepts >= 7, so
// the assert_eq that follows is what rejects an overshoot.
wait_for("counter == 7 (cap reached)", DELAY * 3, || {
counter.load(Ordering::SeqCst) >= 7
})
.await;
assert_eq!(counter.load(Ordering::SeqCst), 7);
// After two more delay periods no further connection may have been opened.
tokio::time::sleep(DELAY * 2).await;
assert_eq!(counter.load(Ordering::SeqCst), 7);
// Teardown: cancel the download, wait for its task, then stop the server.
cancel.cancel();
let _ = dl_task.await;
server_task.abort();
}
/// Starts a local TCP server that counts each accepted connection and closes it
/// straight away without reading or writing anything.
///
/// Returns the bound address, the accepted-connection counter, and the handle of the
/// accept-loop task (the loop ends on its own only if `accept` fails).
async fn spawn_drop_server() -> (
    std::net::SocketAddr,
    Arc<std::sync::atomic::AtomicUsize>,
    tokio::task::JoinHandle<()>,
) {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::net::TcpListener;

    let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
    let addr = listener.local_addr().expect("local addr");
    let accepted = Arc::new(AtomicUsize::new(0));

    let accept_loop = tokio::spawn({
        let accepted = Arc::clone(&accepted);
        async move {
            while let Ok((sock, _)) = listener.accept().await {
                accepted.fetch_add(1, Ordering::SeqCst);
                // Dropping the stream closes the connection immediately.
                drop(sock);
            }
        }
    });

    (addr, accepted, accept_loop)
}
#[tokio::test(flavor = "multi_thread")]
async fn fill_capacity_aborts_ramp_when_batch_part_fails() {
use std::sync::atomic::Ordering;
let (addr, counter, server_task) = spawn_drop_server().await;
let tmp = tempdir().expect("tmp");
let download_dir = tmp.path().join("download");
let save_dir = tmp.path().join("save");
fs::create_dir_all(&download_dir).await.expect("mkdir dl");
fs::create_dir_all(&save_dir).await.expect("mkdir save");
let n_parts: u64 = 10;
let part_size: u64 = 1024 * 1024;
let total = part_size * n_parts;
let mut parts = HashMap::new();
for i in 0..n_parts {
let ulid = format!("p{i}");
parts.insert(ulid.clone(), make_part(&ulid, i * part_size, part_size));
}
let url = format!("http://{}/file", addr);
let instruction =
create_instruction(&download_dir, &save_dir, &url, total, parts, n_parts).await;
let metadata = instruction.as_metadata();
let downloader = Downloader::new(
Arc::clone(&instruction),
metadata,
reqwest::Client::builder().build().expect("client"),
false,
None,
false,
RampupConfig {
enabled: true,
batch_size: 2,
delay_min: Duration::from_millis(50),
delay_max: Duration::from_millis(50),
},
FixedThenExponentialRetry {
max_n_retries: 1,
wait_time: Duration::from_millis(20),
n_fixed_retries: 1,
},
DownloadContext::new(),
);
let cancel = downloader.ctx.cancel.clone();
let dl_task = tokio::spawn(async move {
let _ = downloader.run().await;
});
tokio::time::sleep(Duration::from_millis(800)).await;
cancel.cancel();
let _ = dl_task.await;
let opened = counter.load(Ordering::SeqCst);
assert!(
opened < 30,
"rampup did not throttle on failures: {} connections opened",
opened
);
server_task.abort();
}
/// Checks that, with ramp-up disabled, a six-part download against the hanging test
/// server ends up with exactly six accepted connections.
#[tokio::test(flavor = "multi_thread")]
async fn fill_capacity_no_rampup_opens_all_at_once() {
    use std::sync::atomic::Ordering;

    let (addr, accepted, server_task) = spawn_hanging_http_server().await;
    let (_instruction, downloader, _tmp) =
        build_rampup_test_downloader(addr, 6, RampupConfig::disabled()).await;

    let stop = downloader.ctx.cancel.clone();
    let run_task = tokio::spawn(async move {
        let _ = downloader.run().await;
    });

    // Current number of connections the server has accepted.
    let opened = || accepted.load(Ordering::SeqCst);

    // The wait accepts >= 6; the assert_eq afterwards rejects an overshoot.
    wait_for(
        "counter == 6 (all open at once)",
        Duration::from_secs(5),
        || opened() >= 6,
    )
    .await;
    assert_eq!(opened(), 6);

    stop.cancel();
    let _ = run_task.await;
    server_task.abort();
}
}