mod checksum;
mod downloader;
mod io;
mod recover_metadata;
mod save_conflict;
mod server_conflict;
use std::{path::PathBuf, sync::Arc};
use fs2::FileExt;
use reqwest::{
Client, Proxy, Url,
header::{HeaderMap, USER_AGENT},
};
use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
/// One concurrent-download slot taken from a [`DownloadManager`]'s semaphore via
/// `DownloadManager::acquire_download_permit`.
///
/// The slot is held for as long as this value lives and is released when it is dropped.
#[derive(Debug)]
// The inner permit is never read; it is stored only so that its lifetime (and release on
// drop) is tied to this wrapper, hence the `dead_code` allowance.
pub struct DownloadPermit(#[allow(dead_code)] OwnedSemaphorePermit);
use crate::config::{Config, DownloadOptions};
use crate::download_manager::checksum::check_final_file_checksum;
use crate::download_manager::recover_metadata::recover_metadata;
use crate::download_manager::{
downloader::{Downloader, RampupConfig},
io::persist_metadata,
};
use crate::download_manager::{io::assemble_final_file, server_conflict::resolve_server_conflicts};
use crate::download_manager::{io::remove_all_parts, save_conflict::resolve_save_conflicts};
use crate::error::MetadataError;
use crate::progress::{DownloadContext, Phase, ProgressEvent};
use crate::response_info::ResponseInfo;
use crate::retry_policies::{FixedThenExponentialRetry, wait_for_retry};
use crate::{
conflict::{SaveConflictResolver, ServerConflictResolver},
download_manager::io::sum_parts_on_disk,
};
use crate::{credentials::Credentials, user_agents::random_user_agent};
use crate::{download::Download, download_metadata::PartDetails, error::OdlError};
/// Evaluates URLs into [`Download`] instructions and executes them.
///
/// The manager owns a semaphore sized from `Config::max_concurrent_downloads`. Downloading
/// does not take a slot on its own; callers that want to honour the limit obtain one through
/// `DownloadManager::acquire_download_permit`.
#[derive(Debug)]
pub struct DownloadManager {
    /// Active configuration; its download options are the defaults for requests that do not
    /// supply their own.
    config: Config,
    /// Concurrency limiter whose permit count follows `config.max_concurrent_downloads()`.
    semaphore: Arc<Semaphore>,
}
/// Input for `DownloadManager::evaluate` / `DownloadManager::quick_evaluate`.
///
/// Build it with [`EvaluateRequest::new`] and the optional setter methods.
#[derive(Debug)]
pub struct EvaluateRequest<'a, CR: SaveConflictResolver> {
    /// Resource to evaluate.
    pub url: Url,
    /// Directory the finished file is saved into.
    pub save_dir: PathBuf,
    /// Decides how to proceed when the final file or an earlier download of it already exists.
    pub conflict_resolver: &'a CR,
    /// Optional credentials, sent as HTTP basic auth on the evaluation request and carried
    /// into the resulting [`Download`].
    pub credentials: Option<Credentials>,
    /// Progress/cancellation context; a fresh `DownloadContext` is used when unset.
    pub ctx: Option<&'a DownloadContext>,
    /// Per-request options; the manager's configured download options are used when unset.
    pub options: Option<&'a DownloadOptions>,
}
impl<'a, CR: SaveConflictResolver> EvaluateRequest<'a, CR> {
    /// Starts a request for `url` that saves into `save_dir` and resolves save conflicts with
    /// `conflict_resolver`. Credentials, context and options all start unset.
    pub fn new<P: Into<PathBuf>>(url: Url, save_dir: P, conflict_resolver: &'a CR) -> Self {
        let save_dir: PathBuf = save_dir.into();
        Self {
            url,
            save_dir,
            conflict_resolver,
            credentials: None,
            ctx: None,
            options: None,
        }
    }

    /// Returns the request with `c` attached as credentials.
    pub fn credentials(self, c: Credentials) -> Self {
        Self {
            credentials: Some(c),
            ..self
        }
    }

    /// Returns the request with `c` as its progress/cancellation context.
    pub fn ctx(self, c: &'a DownloadContext) -> Self {
        Self {
            ctx: Some(c),
            ..self
        }
    }

    /// Returns the request with `o` overriding the manager's default download options.
    pub fn options(self, o: &'a DownloadOptions) -> Self {
        Self {
            options: Some(o),
            ..self
        }
    }
}
/// Input for `DownloadManager::download`.
///
/// Build it with [`DownloadRequest::new`] and the optional setter methods.
#[derive(Debug)]
pub struct DownloadRequest<'a, CR: ServerConflictResolver> {
    /// What to download and where to put it.
    pub instruction: Download,
    /// Decides how to proceed when the remote file changed or the download cannot be resumed.
    pub conflict_resolver: &'a CR,
    /// Progress/cancellation context; a fresh `DownloadContext` is used when unset.
    pub ctx: Option<&'a DownloadContext>,
    /// Per-request options; the manager's configured download options are used when unset.
    pub options: Option<&'a DownloadOptions>,
}
impl<'a, CR: ServerConflictResolver> DownloadRequest<'a, CR> {
    /// Starts a request that executes `instruction`, resolving server-side conflicts with
    /// `conflict_resolver`. Context and options start unset.
    pub fn new(instruction: Download, conflict_resolver: &'a CR) -> Self {
        let (ctx, options) = (None, None);
        Self {
            instruction,
            conflict_resolver,
            ctx,
            options,
        }
    }

    /// Returns the request with `c` as its progress/cancellation context.
    pub fn ctx(self, c: &'a DownloadContext) -> Self {
        Self {
            ctx: Some(c),
            ..self
        }
    }

    /// Returns the request with `o` overriding the manager's default download options.
    pub fn options(self, o: &'a DownloadOptions) -> Self {
        Self {
            options: Some(o),
            ..self
        }
    }
}
impl DownloadManager {
pub fn new(config: Config) -> DownloadManager {
let max_concurrent_downloads = config.max_concurrent_downloads();
DownloadManager {
config,
semaphore: Arc::new(Semaphore::new(max_concurrent_downloads)),
}
}
pub fn config(&self) -> &Config {
&self.config
}
pub async fn set_config(&mut self, value: Config) -> Result<(), AcquireError> {
let old_max = self.config.max_concurrent_downloads();
self.config = value;
let new_max = self.config.max_concurrent_downloads();
if new_max > old_max {
let add_count = new_max.saturating_sub(old_max);
self.semaphore.add_permits(add_count);
} else if new_max < old_max {
let forget_count = old_max.saturating_sub(new_max);
let _perm = Arc::clone(&self.semaphore)
.acquire_many_owned(forget_count as u32)
.await?;
_perm.forget();
}
Ok(())
}
pub async fn evaluate<CR>(&self, req: EvaluateRequest<'_, CR>) -> Result<Download, OdlError>
where
CR: SaveConflictResolver,
{
let EvaluateRequest {
url,
save_dir,
conflict_resolver,
credentials,
ctx,
options,
} = req;
let default_ctx;
let ctx = match ctx {
Some(c) => c,
None => {
default_ctx = DownloadContext::new();
&default_ctx
}
};
let opts = options.unwrap_or(self.config.download());
ctx.emit(ProgressEvent::PhaseChanged(Phase::Evaluating));
if ctx.is_cancelled() {
return Err(OdlError::Cancelled);
}
let client = self.get_client(opts)?;
let retry_policy = FixedThenExponentialRetry {
max_n_retries: opts.max_retries(),
wait_time: opts.wait_between_retries(),
n_fixed_retries: opts.n_fixed_retries(),
};
let mut attempts: u32 = 0;
let resp = loop {
let mut req = client
.head(url.clone())
.header(
"Want-Repr-Digest",
"sha-512=9, sha-384=8, sha-256=7, sha-1=1, md5=1",
)
.header(
"Want-Content-Digest",
"sha-512=9, sha-384=8, sha-256=7, sha-1=1, md5=1",
);
if let Some(creds) = &credentials {
req = req.basic_auth(creds.username(), creds.password());
}
if opts.user_agent().is_none() && opts.randomize_user_agent() {
req = req.header(USER_AGENT, random_user_agent());
}
match req.send().await.and_then(|r| r.error_for_status()) {
Ok(r) => break r,
Err(e) => {
attempts = attempts.saturating_add(1);
if !wait_for_retry(&retry_policy, attempts, ctx).await {
return Err(OdlError::from(e));
}
if ctx.is_cancelled() {
return Err(OdlError::Cancelled);
}
}
}
};
let info = ResponseInfo::from(resp);
let instruction = Download::from_response_info(
self.config.download_dir(),
save_dir,
info,
opts.max_connections(),
opts.use_server_time(),
credentials,
Option::<Proxy>::from(opts),
Some(HeaderMap::from(opts)),
);
ctx.emit(ProgressEvent::PhaseChanged(Phase::ResolvingConflicts));
let instruction = resolve_save_conflicts(instruction, conflict_resolver).await?;
ctx.emit(ProgressEvent::FilenameResolved(
instruction.filename().to_string(),
));
ctx.emit(ProgressEvent::Progress {
downloaded: 0,
total: instruction.size(),
});
Ok(instruction)
}
pub async fn quick_evaluate<CR>(
&self,
req: EvaluateRequest<'_, CR>,
) -> Result<Download, OdlError>
where
CR: SaveConflictResolver,
{
let EvaluateRequest {
url,
save_dir,
conflict_resolver,
credentials,
ctx,
options,
} = req;
let default_ctx;
let ctx = match ctx {
Some(c) => c,
None => {
default_ctx = DownloadContext::new();
&default_ctx
}
};
let opts = options.unwrap_or(self.config.download());
ctx.emit(ProgressEvent::PhaseChanged(Phase::Evaluating));
if ctx.is_cancelled() {
return Err(OdlError::Cancelled);
}
let info = ResponseInfo::from_url(url);
let instruction = Download::from_response_info(
self.config.download_dir(),
save_dir,
info,
opts.max_connections(),
opts.use_server_time(),
credentials,
Option::<Proxy>::from(opts),
Some(HeaderMap::from(opts)),
);
ctx.emit(ProgressEvent::PhaseChanged(Phase::ResolvingConflicts));
let instruction = resolve_save_conflicts(instruction, conflict_resolver).await?;
ctx.emit(ProgressEvent::FilenameResolved(
instruction.filename().to_string(),
));
ctx.emit(ProgressEvent::Progress {
downloaded: 0,
total: instruction.size(),
});
Ok(instruction)
}
pub async fn download<CR>(&self, req: DownloadRequest<'_, CR>) -> Result<PathBuf, OdlError>
where
CR: ServerConflictResolver,
{
let DownloadRequest {
instruction,
conflict_resolver,
ctx,
options,
} = req;
let default_ctx;
let ctx = match ctx {
Some(c) => c,
None => {
default_ctx = DownloadContext::new();
&default_ctx
}
};
let opts = options.unwrap_or(self.config.download());
let result = self
.download_inner(instruction, conflict_resolver, ctx, opts)
.await;
match &result {
Ok(_) => {}
Err(OdlError::Cancelled) => ctx.emit(ProgressEvent::Cancelled),
Err(e) => ctx.emit(ProgressEvent::Failed {
message: e.to_string(),
}),
}
result
}
async fn download_inner<CR>(
&self,
instruction: Download,
conflict_resolver: &CR,
ctx: &DownloadContext,
opts: &DownloadOptions,
) -> Result<PathBuf, OdlError>
where
CR: ServerConflictResolver,
{
tokio::fs::create_dir_all(instruction.download_dir()).await?;
let f = tokio::fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(true)
.open(instruction.lockfile_path())
.await
.map_err(|e| OdlError::StdIoError {
e,
extra_info: Some(format!(
"Failed to open lockfile for exclusive locking at {}",
instruction.lockfile_path().display(),
)),
})?;
let f = f.into_std().await;
if f.try_lock_exclusive().is_err() {
return Err(OdlError::MetadataError(MetadataError::LockfileInUse));
}
let result = self
.process_download(instruction, conflict_resolver, ctx, opts)
.await;
let _ = FileExt::unlock(&f);
result
}
pub async fn acquire_download_permit(&self) -> Result<DownloadPermit, AcquireError> {
Arc::clone(&self.semaphore)
.acquire_owned()
.await
.map(DownloadPermit)
}
fn get_client(&self, opts: &DownloadOptions) -> Result<Client, OdlError> {
let mut client = reqwest::Client::builder();
if opts.headers().is_some_and(|x| !x.is_empty()) {
client = client.default_headers(HeaderMap::from(opts));
}
if let Some(proxy) = Option::<Proxy>::from(opts) {
client = client.proxy(proxy);
}
if opts.accept_invalid_certs() {
client = client.danger_accept_invalid_certs(opts.accept_invalid_certs())
}
if let Some(user_agent) = opts.user_agent() {
client = client.user_agent(user_agent.to_owned());
}
if let Some(timeout) = opts.connect_timeout() {
client = client.connect_timeout(timeout);
}
if !opts.http2() {
client = client.http1_only();
} else {
client = client.http2_adaptive_window(true);
}
Ok(client.build()?)
}
async fn process_download<CR>(
&self,
instruction: Download,
conflict_resolver: &CR,
ctx: &DownloadContext,
opts: &DownloadOptions,
) -> Result<PathBuf, OdlError>
where
CR: ServerConflictResolver,
{
if ctx.is_cancelled() {
return Err(OdlError::Cancelled);
}
tokio::fs::create_dir_all(instruction.save_dir()).await?;
recover_metadata(&instruction).await?;
ctx.emit(ProgressEvent::PhaseChanged(Phase::ResolvingConflicts));
let mut metadata = resolve_server_conflicts(&instruction, conflict_resolver).await?;
let initial_on_disk = if metadata.finished {
sum_parts_on_disk(&instruction, &metadata).await
} else {
None
};
if let Some(sum_of_parts_sizes) = initial_on_disk {
let size: Option<u64> = metadata.size.or_else(|| instruction.size());
ctx.emit(ProgressEvent::Progress {
downloaded: sum_of_parts_sizes,
total: size,
});
}
if !metadata.finished {
let final_path_recovery = instruction.final_file_path();
if !metadata.checksums.is_empty()
&& tokio::fs::try_exists(&final_path_recovery)
.await
.unwrap_or(false)
&& check_final_file_checksum(&metadata, &instruction, false)
.await
.is_ok()
{
metadata.finished = true;
persist_metadata(&metadata, &instruction).await?;
remove_all_parts(instruction.download_dir()).await;
ctx.emit(ProgressEvent::Completed {
path: final_path_recovery.clone(),
already_complete: true,
});
return Ok(final_path_recovery);
}
if opts.max_connections() > metadata.max_connections {
grow_parts(&instruction, &mut metadata, opts.max_connections()).await?;
persist_metadata(&metadata, &instruction).await?;
}
let to_download = metadata
.parts
.values()
.filter_map(|p| if !p.finished { Some(p.clone()) } else { None })
.collect::<Vec<PartDetails>>();
if !to_download.is_empty() {
let randomize_user_agent = if opts.user_agent().is_some() {
false
} else {
opts.randomize_user_agent()
};
let client = self.get_client(opts)?;
let retry_policy = crate::retry_policies::FixedThenExponentialRetry {
max_n_retries: opts.max_retries(),
wait_time: opts.wait_between_retries(),
n_fixed_retries: opts.n_fixed_retries(),
};
ctx.emit(ProgressEvent::PhaseChanged(Phase::Downloading));
let downloader = Downloader::new(
Arc::new(instruction.clone()),
metadata,
client,
randomize_user_agent,
opts.speed_limit(),
opts.dynamic_split(),
RampupConfig {
enabled: opts.rampup(),
batch_size: opts.rampup_batch_size(),
delay_min: opts.rampup_delay_min(),
delay_max: opts.rampup_delay_max(),
},
retry_policy,
ctx.clone(),
);
let mdata = downloader.run().await?;
persist_metadata(&mdata, &instruction).await?;
metadata = mdata;
}
let final_path_for_cleanup = instruction.final_file_path();
if tokio::fs::try_exists(&final_path_for_cleanup)
.await
.unwrap_or(false)
{
let _ = tokio::fs::remove_file(&final_path_for_cleanup).await;
}
ctx.emit(ProgressEvent::PhaseChanged(Phase::Assembling));
let final_path = assemble_final_file(&metadata, &instruction, ctx).await?;
metadata.finished = true;
persist_metadata(&metadata, &instruction).await?;
remove_all_parts(instruction.download_dir()).await;
ctx.emit(ProgressEvent::Completed {
path: final_path.clone(),
already_complete: false,
});
Ok(final_path)
} else {
let final_path = instruction.final_file_path();
if tokio::fs::try_exists(&final_path).await.unwrap_or(false) {
ctx.emit(ProgressEvent::Completed {
path: final_path.clone(),
already_complete: true,
});
Ok(final_path)
} else {
Err(OdlError::StdIoError {
e: std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Expected final file not found at {}", final_path.display()),
),
extra_info: None,
})
}
}
}
}
/// Splits unfinished parts of `metadata` until at least `target` parts are unfinished, or no
/// remaining part can be split.
///
/// Every round picks the unfinished, known-size part with the largest right-hand remainder.
/// The remainder is computed by `Download::compute_split` from the bytes already written for
/// that part. The chosen part keeps its offset and shrinks to the left half, and a new part
/// with a fresh ULID covers the remainder. Finally `metadata.max_connections` is set to the
/// total number of parts.
///
/// # Errors
/// Returns `OdlError::StdIoError` if a part file exists but cannot be stat'ed. A missing part
/// file counts as zero bytes written.
async fn grow_parts(
    instruction: &Download,
    metadata: &mut crate::download_metadata::DownloadMetadata,
    target: u64,
) -> Result<(), OdlError> {
    let wanted = target as usize;

    // Bytes currently on disk for each part, keyed by part id.
    let mut written_by_part: std::collections::HashMap<String, u64> =
        std::collections::HashMap::with_capacity(metadata.parts.len());
    for part_id in metadata.parts.keys() {
        let part_file = instruction.part_path(part_id);
        let written = match tokio::fs::metadata(&part_file).await {
            Ok(stat) => stat.len(),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => 0,
            Err(err) => {
                return Err(OdlError::StdIoError {
                    e: err,
                    extra_info: Some(format!(
                        "grow_parts: failed to stat part at {}",
                        part_file.display(),
                    )),
                });
            }
        };
        written_by_part.insert(part_id.clone(), written);
    }

    while metadata.parts.values().filter(|p| !p.finished).count() < wanted {
        let best = metadata
            .parts
            .values()
            .filter(|p| !p.finished && p.size != Download::UNKNOWN_PART_SIZE)
            .filter_map(|p| {
                let written = written_by_part.get(&p.ulid).copied().unwrap_or(0);
                Download::compute_split(p.offset, p.size, written, Download::MIN_PART_SIZE)
                    .map(|split| (p.ulid.clone(), split))
            })
            .max_by_key(|(_, split)| split.new_right_size);
        let Some((split_id, split)) = best else {
            break;
        };

        if let Some(left) = metadata.parts.get_mut(&split_id) {
            left.size = split.new_left_size;
        }

        let right_id = ulid::Ulid::new().to_string();
        written_by_part.insert(right_id.clone(), 0);
        metadata.parts.insert(
            right_id.clone(),
            crate::download_metadata::PartDetails {
                offset: split.new_right_offset,
                size: split.new_right_size,
                ulid: right_id,
                finished: false,
            },
        );
    }

    metadata.max_connections = metadata.parts.len() as u64;
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::DownloadOptionsBuilder;
use crate::conflict::FileChangedResolution;
use crate::conflict::NotResumableResolution;
use crate::conflict::ServerConflict;
use crate::download::DownloadBuilder;
use crate::download_metadata::PartDetails;
use crate::error::ConflictError;
use async_trait::async_trait;
use mockito::Matcher;
use mockito::Server;
use std::collections::HashMap;
use std::path::Path;
use std::time::Duration;
use tempfile::tempdir;
use tokio::fs;
use tokio::io::AsyncWriteExt;
fn test_cfg(download_dir: &Path, max_connections: u64) -> Config {
crate::config::ConfigBuilder::default()
.download_dir(download_dir.to_path_buf())
.download(
DownloadOptionsBuilder::default()
.max_connections(max_connections)
.build()
.unwrap(),
)
.build()
.unwrap()
}
fn build_dummy_instruction(download_dir: &Path) -> Download {
let mut parts = HashMap::new();
parts.insert(
"p0".to_string(),
PartDetails {
ulid: "p0".to_string(),
offset: 0,
size: 0,
finished: false,
},
);
DownloadBuilder::default()
.download_dir(download_dir.to_path_buf())
.save_dir(download_dir.to_path_buf())
.filename("dummy".to_string())
.url(Url::parse("http://example.invalid/x").unwrap())
.size(Some(0))
.max_connections(1)
.parts(parts)
.is_resumable(true)
.build()
.unwrap()
}
#[tokio::test]
async fn grow_parts_increases_unfinished_count() -> Result<(), Box<dyn std::error::Error>> {
let tmp = tempdir()?;
let instruction = build_dummy_instruction(tmp.path());
let large_size = Download::MIN_PART_SIZE * 16;
let mut parts = HashMap::new();
parts.insert(
"big".to_string(),
PartDetails {
ulid: "big".to_string(),
offset: 0,
size: large_size,
finished: false,
},
);
let mut metadata = crate::download_metadata::DownloadMetadata {
url: "http://example.invalid/x".to_string(),
filename: "dummy".to_string(),
save_dir: tmp.path().to_string_lossy().into_owned(),
is_resumable: true,
use_server_time: false,
last_modified: None,
last_etag: None,
size: Some(large_size),
checksums: vec![],
requires_auth: false,
requires_basic_auth: false,
headers: HashMap::new(),
max_connections: 1,
parts,
finished: false,
};
grow_parts(&instruction, &mut metadata, 4).await?;
let unfinished = metadata.parts.values().filter(|p| !p.finished).count();
assert_eq!(unfinished, 4, "should grow to 4 unfinished parts");
assert_eq!(metadata.max_connections, 4);
let total: u64 = metadata.parts.values().map(|p| p.size).sum();
assert_eq!(total, large_size, "total coverage must not change");
Ok(())
}
#[tokio::test]
async fn grow_parts_is_noop_when_target_le_current() -> Result<(), Box<dyn std::error::Error>> {
let tmp = tempdir()?;
let instruction = build_dummy_instruction(tmp.path());
let large_size = Download::MIN_PART_SIZE * 16;
let mut parts = HashMap::new();
parts.insert(
"a".to_string(),
PartDetails {
ulid: "a".to_string(),
offset: 0,
size: large_size / 2,
finished: false,
},
);
parts.insert(
"b".to_string(),
PartDetails {
ulid: "b".to_string(),
offset: large_size / 2,
size: large_size / 2,
finished: false,
},
);
let mut metadata = crate::download_metadata::DownloadMetadata {
url: "http://example.invalid/x".to_string(),
filename: "dummy".to_string(),
save_dir: tmp.path().to_string_lossy().into_owned(),
is_resumable: true,
use_server_time: false,
last_modified: None,
last_etag: None,
size: Some(large_size),
checksums: vec![],
requires_auth: false,
requires_basic_auth: false,
headers: HashMap::new(),
max_connections: 2,
parts,
finished: false,
};
grow_parts(&instruction, &mut metadata, 2).await?;
assert_eq!(metadata.parts.len(), 2);
Ok(())
}
#[test]
fn dynamic_split_default_is_true() {
let opts = DownloadOptionsBuilder::default().build().unwrap();
assert!(opts.dynamic_split());
}
#[test]
fn dynamic_split_can_be_disabled_via_builder() {
let opts = DownloadOptionsBuilder::default()
.dynamic_split(false)
.build()
.unwrap();
assert!(!opts.dynamic_split());
}
#[test]
fn rampup_defaults() {
let opts = DownloadOptionsBuilder::default().build().unwrap();
assert!(opts.rampup());
assert_eq!(opts.rampup_batch_size(), 2);
assert_eq!(opts.rampup_delay_min(), Duration::from_millis(300));
assert_eq!(opts.rampup_delay_max(), Duration::from_millis(1000));
}
#[test]
fn rampup_builder_clamps_zero_batch_size() {
let opts = DownloadOptionsBuilder::default()
.rampup_batch_size(0)
.build()
.unwrap();
assert!(opts.rampup_batch_size() >= 1);
}
#[test]
fn rampup_builder_rejects_inverted_delays() {
use crate::config::DownloadOptionsBuilderError;
let err = DownloadOptionsBuilder::default()
.rampup(true)
.rampup_delay_min(Duration::from_millis(1000))
.rampup_delay_max(Duration::from_millis(500))
.build()
.expect_err("expected error");
assert!(matches!(
err,
DownloadOptionsBuilderError::ValidationError(_)
));
}
#[test]
fn rampup_inverted_delays_ignored_when_disabled() {
let opts = DownloadOptionsBuilder::default()
.rampup(false)
.rampup_delay_min(Duration::from_millis(1000))
.rampup_delay_max(Duration::from_millis(500))
.build()
.unwrap();
assert!(!opts.rampup());
}
#[tokio::test]
async fn grow_parts_preserves_partial_progress() -> Result<(), Box<dyn std::error::Error>> {
let tmp = tempdir()?;
let instruction = build_dummy_instruction(tmp.path());
let total_size = Download::MIN_PART_SIZE * 16;
let already_written = Download::MIN_PART_SIZE * 3;
let part_path = instruction.part_path("big");
tokio::fs::write(&part_path, vec![0u8; already_written as usize]).await?;
let mut parts = HashMap::new();
parts.insert(
"big".to_string(),
PartDetails {
ulid: "big".to_string(),
offset: 0,
size: total_size,
finished: false,
},
);
let mut metadata = crate::download_metadata::DownloadMetadata {
url: "http://example.invalid/x".to_string(),
filename: "dummy".to_string(),
save_dir: tmp.path().to_string_lossy().into_owned(),
is_resumable: true,
use_server_time: false,
last_modified: None,
last_etag: None,
size: Some(total_size),
checksums: vec![],
requires_auth: false,
requires_basic_auth: false,
headers: HashMap::new(),
max_connections: 1,
parts,
finished: false,
};
grow_parts(&instruction, &mut metadata, 2).await?;
let original = metadata.parts.get("big").expect("original part survives");
assert_eq!(original.offset, 0);
assert!(
original.size >= already_written,
"split point must not invalidate already-downloaded bytes (size={}, written={})",
original.size,
already_written,
);
let total: u64 = metadata.parts.values().map(|p| p.size).sum();
assert_eq!(total, total_size);
Ok(())
}
struct AlwaysAbortResolver;
#[async_trait]
impl ServerConflictResolver for AlwaysAbortResolver {
async fn resolve_file_changed(&self, _: &Download) -> FileChangedResolution {
FileChangedResolution::Abort
}
async fn resolve_not_resumable(&self, _: &Download) -> NotResumableResolution {
NotResumableResolution::Abort
}
}
struct AlwaysReplaceResolver;
#[async_trait]
impl SaveConflictResolver for AlwaysReplaceResolver {
async fn final_file_exists(
&self,
_: &Download,
) -> crate::conflict::FinalFileExistsResolution {
crate::conflict::FinalFileExistsResolution::ReplaceAndContinue
}
async fn same_download_exists(
&self,
_: &Download,
) -> crate::conflict::SameDownloadExistsResolution {
crate::conflict::SameDownloadExistsResolution::Resume
}
}
#[tokio::test]
async fn test_download_manager_multipart_download() -> Result<(), Box<dyn std::error::Error>> {
let file_content = b"HelloWorldThisIsATestFile";
let part1 = &file_content[..10]; let part2 = &file_content[10..];
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/testfile")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "testetag")
.with_header("last-modified", "Wed, 21 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let get_mock1 = server
.mock("GET", "/testfile")
.match_header("range", Matcher::Exact("bytes=0-9".into()))
.with_status(206)
.with_body(part1)
.create_async()
.await;
let get_mock2 = server
.mock("GET", "/testfile")
.match_header(
"range",
Matcher::Exact(format!("bytes=10-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(part2)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = test_cfg(tmp_data_dir.path(), 2);
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/testfile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&save_resolver,
))
.await?;
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(2)
.parts({
let mut parts = HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: 10,
finished: false,
},
);
parts.insert(
"part2".to_string(),
PartDetails {
ulid: "part2".to_string(),
offset: 10,
size: (file_content.len() - 10) as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm
.download(DownloadRequest::new(instruction, &resolver))
.await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock1.assert_async().await;
get_mock2.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_save_conflict_final_file_exists_abort() -> Result<(), Box<dyn std::error::Error>>
{
let mut server = Server::new_async().await;
let base = server.url();
let head_mock = server
.mock("HEAD", "/file_abort")
.with_status(200)
.with_header("content-length", "1")
.with_header("accept-ranges", "bytes")
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let filename = "file_abort";
let final_path = tmp_save_dir.path().join(filename);
tokio::fs::write(&final_path, b"x").await?;
let cfg = test_cfg(tmp_data_dir.path(), 1);
let dlm = DownloadManager::new(cfg);
struct AbortFinalResolver;
#[async_trait]
impl SaveConflictResolver for AbortFinalResolver {
async fn final_file_exists(
&self,
_: &Download,
) -> crate::conflict::FinalFileExistsResolution {
crate::conflict::FinalFileExistsResolution::Abort
}
async fn same_download_exists(
&self,
_: &Download,
) -> crate::conflict::SameDownloadExistsResolution {
crate::conflict::SameDownloadExistsResolution::Resume
}
}
let resolver = AbortFinalResolver {};
let result = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/file_abort", base)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&resolver,
))
.await;
assert!(matches!(
result,
Err(OdlError::Conflict(ConflictError::Save {
conflict: crate::conflict::SaveConflict::FinalFileExists
}))
));
head_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_save_conflict_final_file_exists_add_number()
-> Result<(), Box<dyn std::error::Error>> {
let mut server = Server::new_async().await;
let base = server.url();
let head_mock = server
.mock("HEAD", "/file_add")
.with_status(200)
.with_header("content-length", "1")
.with_header("accept-ranges", "bytes")
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let filename = "file_add";
let final_path = tmp_save_dir.path().join(filename);
tokio::fs::write(&final_path, b"x").await?;
let cfg = test_cfg(tmp_data_dir.path(), 1);
let dlm = DownloadManager::new(cfg);
struct AddNumberResolver;
#[async_trait]
impl SaveConflictResolver for AddNumberResolver {
async fn final_file_exists(
&self,
_: &Download,
) -> crate::conflict::FinalFileExistsResolution {
crate::conflict::FinalFileExistsResolution::AddNumberToNameAndContinue
}
async fn same_download_exists(
&self,
_: &Download,
) -> crate::conflict::SameDownloadExistsResolution {
crate::conflict::SameDownloadExistsResolution::Resume
}
}
let resolver = AddNumberResolver {};
let instruction = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/file_add", base)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&resolver,
))
.await?;
assert_eq!(instruction.filename(), "file_add_2");
head_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_download_manager_single_part_download() -> Result<(), Box<dyn std::error::Error>>
{
let file_content = b"SinglePartFileContent";
let part = &file_content[..];
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/singlefile")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "singleetag")
.with_header("last-modified", "Thu, 22 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let get_mock = server
.mock("GET", "/singlefile")
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(part)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = test_cfg(tmp_data_dir.path(), 1);
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/singlefile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&save_resolver,
))
.await?;
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts({
let mut parts = HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: file_content.len() as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm
.download(DownloadRequest::new(instruction, &resolver))
.await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_download_manager_multipart_not_resumable_download()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"NonResumableMultipartFile";
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/nonresumablefile")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("etag", "nonresumableetag")
.with_header("last-modified", "Fri, 23 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = test_cfg(tmp_data_dir.path(), 2);
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/nonresumablefile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&save_resolver,
))
.await?;
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(2)
.parts({
let mut parts = HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: 10,
finished: false,
},
);
parts.insert(
"part2".to_string(),
PartDetails {
ulid: "part2".to_string(),
offset: 10,
size: (file_content.len() - 10) as u64,
finished: false,
},
);
parts
})
.is_resumable(false)
.build()
.unwrap();
struct AssertTestResolver;
#[async_trait]
impl ServerConflictResolver for AssertTestResolver {
async fn resolve_file_changed(&self, _: &Download) -> FileChangedResolution {
FileChangedResolution::Abort
}
async fn resolve_not_resumable(&self, _: &Download) -> NotResumableResolution {
NotResumableResolution::Abort
}
}
let resolver = AssertTestResolver {};
let result = dlm
.download(DownloadRequest::new(instruction, &resolver))
.await;
assert!(matches!(
result,
Err(OdlError::Conflict(ConflictError::Server {
conflict: ServerConflict::NotResumable
}))
));
head_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_download_manager_multipart_not_resumable_restart_download()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"NonResumableMultipartFile";
let part = &file_content[..];
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/nonresumablefile_restart")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("etag", "nonresumableetag")
.with_header("last-modified", "Fri, 23 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let get_mock = server
.mock("GET", "/nonresumablefile_restart")
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(part)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = test_cfg(tmp_data_dir.path(), 2);
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/nonresumablefile_restart", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&save_resolver,
))
.await?;
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(2)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: 10,
finished: false,
},
);
parts.insert(
"part2".to_string(),
PartDetails {
ulid: "part2".to_string(),
offset: 10,
size: (file_content.len() - 10) as u64,
finished: false,
},
);
parts
})
.is_resumable(false)
.build()
.unwrap();
struct AssertTestResolver;
#[async_trait]
impl ServerConflictResolver for AssertTestResolver {
async fn resolve_file_changed(&self, _: &Download) -> FileChangedResolution {
FileChangedResolution::Restart
}
async fn resolve_not_resumable(&self, _: &Download) -> NotResumableResolution {
NotResumableResolution::Restart
}
}
let resolver = AssertTestResolver {};
let final_path = dlm
.download(DownloadRequest::new(instruction, &resolver))
.await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_download_manager_zero_byte_single_part_download()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"";
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/zerofile")
.with_status(200)
.with_header("content-length", "0")
.with_header("accept-ranges", "bytes")
.with_header("etag", "zeroetag")
.with_header("last-modified", "Sat, 24 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = test_cfg(tmp_data_dir.path(), 1);
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/zerofile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&save_resolver,
))
.await?;
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(0))
.max_connections(1)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: 0,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm
.download(DownloadRequest::new(instruction, &resolver))
.await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_download_part_resumes_with_correct_range()
-> Result<(), Box<dyn std::error::Error>> {
    // Drives `Downloader` directly against a part that already has a prefix on disk.
    // It must request only the missing suffix, append it, and mark the part finished.
    let file_content = b"PartialDownloadTestFile";
    let part_offset: u64 = 0;
    let part_size = file_content.len() as u64;
    let already_downloaded: u64 = 7;
    let mut server = Server::new_async().await;
    let url = server.url();
    // The mock only answers this exact suffix range, so a wrong resume offset fails the test.
    let expected_range = format!(
        "bytes={}-{}",
        part_offset + already_downloaded,
        part_offset + part_size - 1
    );
    let get_mock = server
        .mock("GET", "/partialfile")
        .match_header("range", Matcher::Exact(expected_range))
        .with_status(206)
        .with_body(&file_content[already_downloaded as usize..])
        .create_async()
        .await;
    let tmp_dir = tempdir()?;
    let download_dir = tmp_dir.path().join("partial");
    fs::create_dir_all(&download_dir).await?;
    let part_path = download_dir.join("part1.part");
    {
        let mut f = fs::File::create(&part_path).await?;
        f.write_all(&file_content[..already_downloaded as usize])
            .await?;
        // tokio::fs::File hands writes to a background blocking task, and dropping
        // the handle does not wait for them. Flushing guarantees the whole prefix is
        // on disk before the downloader measures the part to compute its resume offset.
        f.flush().await?;
    }
    let part_details = PartDetails {
        ulid: "part1".to_string(),
        offset: part_offset,
        size: part_size,
        finished: false,
    };
    let mut parts_map = HashMap::new();
    parts_map.insert(part_details.ulid.clone(), part_details);
    let instruction = DownloadBuilder::default()
        .download_dir(download_dir.clone())
        .save_dir(tmp_dir.path().to_path_buf())
        .filename("partialfile.bin".to_string())
        .url(Url::parse(&format!("{}/partialfile", url)).unwrap())
        .is_resumable(true)
        .max_connections(1)
        .size(Some(part_size))
        .parts(parts_map)
        .build()
        .unwrap();
    let metadata = instruction.as_metadata();
    let client = reqwest::Client::builder().build()?;
    let retry_policy = crate::retry_policies::FixedThenExponentialRetry {
        max_n_retries: 6,
        wait_time: std::time::Duration::from_millis(100),
        n_fixed_retries: 3,
    };
    let downloader = Downloader::new(
        Arc::new(instruction.clone()),
        metadata,
        client,
        false,
        None,
        true,
        RampupConfig::disabled(),
        retry_policy,
        DownloadContext::new(),
    );
    let updated_metadata = downloader.run().await?;
    assert!(
        updated_metadata
            .parts
            .get("part1")
            .is_some_and(|p| p.finished),
        "part1 must be marked finished after the resumed download"
    );
    let result = fs::read(&part_path).await?;
    assert_eq!(result, file_content);
    get_mock.assert_async().await;
    Ok(())
}
#[tokio::test]
async fn test_download_manager_custom_user_agent() -> Result<(), Box<dyn std::error::Error>> {
let file_content = b"UserAgentTestFile";
let mut server = Server::new_async().await;
let url = server.url();
let custom_ua = "MyCustomUserAgent/1.0";
let head_mock = server
.mock("HEAD", "/useragentfile")
.match_header("user-agent", Matcher::Exact(custom_ua.into()))
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "uaetag")
.with_header("last-modified", "Sun, 25 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let get_mock = server
.mock("GET", "/useragentfile")
.match_header("user-agent", Matcher::Exact(custom_ua.into()))
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(file_content)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.download(
DownloadOptionsBuilder::default()
.max_connections(1)
.user_agent(Some(custom_ua.to_string()))
.randomize_user_agent(false)
.build()
.unwrap(),
)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/useragentfile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&save_resolver,
))
.await?;
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: file_content.len() as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm
.download(DownloadRequest::new(instruction, &resolver))
.await?;
let result = tokio::fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_per_job_options_override_user_agent_end_to_end()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"per-job override payload";
let mut server = Server::new_async().await;
let url = server.url();
let manager_ua = "ManagerUA/1.0";
let job_ua = "PerJobUA/2.0";
let head_mock = server
.mock("HEAD", "/perjob")
.match_header("user-agent", Matcher::Exact(job_ua.into()))
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "pjetag")
.create_async()
.await;
let get_mock = server
.mock("GET", "/perjob")
.match_header("user-agent", Matcher::Exact(job_ua.into()))
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(file_content)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.download(
DownloadOptionsBuilder::default()
.max_connections(1)
.user_agent(Some(manager_ua.to_string()))
.build()
.unwrap(),
)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let job_opts = DownloadOptionsBuilder::default()
.max_connections(1)
.user_agent(Some(job_ua.to_string()))
.build()
.unwrap();
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
EvaluateRequest::new(
Url::parse(&format!("{}/perjob", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&save_resolver,
)
.options(&job_opts),
)
.await?;
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: file_content.len() as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm
.download(DownloadRequest::new(instruction, &resolver).options(&job_opts))
.await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_e2e_evaluate_download_assemble_with_checksum()
-> Result<(), Box<dyn std::error::Error>> {
    // Full pipeline without a hand-built plan: `evaluate` must pick up the SHA-256
    // advertised via `Repr-Digest`, and `download` must produce a file whose bytes,
    // and therefore digest, match what the server served.
    use base64::Engine;
    use sha2::{Digest, Sha256};

    // Standard-alphabet base64 of the SHA-256 of `bytes`.
    let sha256_base64 = |bytes: &[u8]| {
        let mut digest = Sha256::new();
        digest.update(bytes);
        base64::engine::general_purpose::STANDARD.encode(digest.finalize())
    };

    let payload = b"E2E full pipeline payload: evaluate -> download -> assemble -> verify";
    let expected_b64 = sha256_base64(&payload[..]);
    let advertised_digest = format!("sha-256=:{}:", expected_b64);

    let mut server = Server::new_async().await;
    let endpoint = format!("{}/payload.bin", server.url());
    let head = server
        .mock("HEAD", "/payload.bin")
        .with_status(200)
        .with_header("content-length", &payload.len().to_string())
        .with_header("accept-ranges", "bytes")
        .with_header("etag", "e2eetag")
        .with_header("last-modified", "Tue, 27 Oct 2015 07:28:00 GMT")
        .with_header("Repr-Digest", &advertised_digest)
        .create_async()
        .await;
    let full_range = format!("bytes=0-{}", payload.len() - 1);
    let get = server
        .mock("GET", "/payload.bin")
        .match_header("range", Matcher::Exact(full_range))
        .with_status(206)
        .with_body(payload)
        .create_async()
        .await;

    let data_dir = tempfile::tempdir()?;
    let save_dir = tempfile::tempdir()?;
    let manager = DownloadManager::new(test_cfg(data_dir.path(), 1));
    let replace = AlwaysReplaceResolver {};
    let evaluated = manager
        .evaluate(EvaluateRequest::new(
            Url::parse(&endpoint).unwrap(),
            save_dir.path().to_path_buf(),
            &replace,
        ))
        .await?;
    assert!(
        !evaluated.as_metadata().checksums.is_empty(),
        "evaluate did not extract checksum from Repr-Digest"
    );
    assert_eq!(evaluated.size(), Some(payload.len() as u64));

    let abort = AlwaysAbortResolver {};
    let saved_to = manager
        .download(DownloadRequest::new(evaluated, &abort))
        .await?;
    let on_disk = fs::read(&saved_to).await?;
    assert_eq!(on_disk, payload, "final file content mismatch");
    assert_eq!(sha256_base64(&on_disk[..]), expected_b64);

    head.assert_async().await;
    get.assert_async().await;
    Ok(())
}
#[tokio::test]
async fn test_e2e_download_fails_on_checksum_mismatch() -> Result<(), Box<dyn std::error::Error>>
{
    // The server advertises a `Repr-Digest` that does not match the body it serves.
    // The body is still fetched (the GET mock is asserted), but the download must
    // end in a checksum-mismatch conflict error.
    let served = b"payload-served-by-server";
    let bogus_repr_digest = "sha-256=:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=:";
    let last_byte = served.len() - 1;

    let mut server = Server::new_async().await;
    let endpoint = format!("{}/bad.bin", server.url());
    let head = server
        .mock("HEAD", "/bad.bin")
        .with_status(200)
        .with_header("content-length", &served.len().to_string())
        .with_header("accept-ranges", "bytes")
        .with_header("Repr-Digest", bogus_repr_digest)
        .create_async()
        .await;
    let get = server
        .mock("GET", "/bad.bin")
        .match_header("range", Matcher::Exact(format!("bytes=0-{last_byte}")))
        .with_status(206)
        .with_body(served)
        .create_async()
        .await;

    let data_dir = tempfile::tempdir()?;
    let save_dir = tempfile::tempdir()?;
    let manager = DownloadManager::new(test_cfg(data_dir.path(), 1));
    let replace = AlwaysReplaceResolver {};
    let evaluated = manager
        .evaluate(EvaluateRequest::new(
            Url::parse(&endpoint).unwrap(),
            save_dir.path().to_path_buf(),
            &replace,
        ))
        .await?;

    let abort = AlwaysAbortResolver {};
    let outcome = manager
        .download(DownloadRequest::new(evaluated, &abort))
        .await;
    let is_mismatch = matches!(
        outcome,
        Err(OdlError::Conflict(ConflictError::ChecksumMismatch { .. }))
    );
    assert!(is_mismatch, "expected ChecksumMismatch, got {:?}", outcome);

    head.assert_async().await;
    get.assert_async().await;
    Ok(())
}
#[tokio::test]
async fn test_e2e_multipart_evaluate_download_assemble_with_checksum()
-> Result<(), Box<dyn std::error::Error>> {
    // 900 KiB of seeded pseudo-random bytes, served as three ranged GETs.
    // `evaluate` must plan three contiguous parts covering every byte, and the
    // assembled file must match byte for byte and by SHA-256.
    use base64::Engine;
    use sha2::{Digest, Sha256};
    use rand::{RngExt, SeedableRng, rngs::StdRng};

    let size: usize = 900 * 1024;
    let mut rng = StdRng::seed_from_u64(0x00C0_FFEE_F00D);
    let file_content: Vec<u8> = std::iter::repeat_with(|| rng.random()).take(size).collect();

    // Standard-alphabet base64 of the SHA-256 of `bytes`.
    let sha256_base64 = |bytes: &[u8]| {
        let mut digest = Sha256::new();
        digest.update(bytes);
        base64::engine::general_purpose::STANDARD.encode(digest.finalize())
    };
    let expected_b64 = sha256_base64(&file_content[..]);
    let advertised_digest = format!("sha-256=:{}:", expected_b64);

    let mut server = Server::new_async().await;
    let endpoint = format!("{}/big.bin", server.url());
    let head = server
        .mock("HEAD", "/big.bin")
        .with_status(200)
        .with_header("content-length", &size.to_string())
        .with_header("accept-ranges", "bytes")
        .with_header("etag", "bigetag")
        .with_header("Repr-Digest", &advertised_digest)
        .create_async()
        .await;

    // Inclusive byte ranges for three equal thirds; the last one runs to EOF.
    let third = size / 3;
    let ranges: Vec<(usize, usize)> = (0..3)
        .map(|i| {
            let start = i * third;
            let end = if i == 2 { size - 1 } else { start + third - 1 };
            (start, end)
        })
        .collect();
    let mut range_mocks = Vec::with_capacity(ranges.len());
    for &(start, end) in &ranges {
        let mock = server
            .mock("GET", "/big.bin")
            .match_header("range", Matcher::Exact(format!("bytes={}-{}", start, end)))
            .with_status(206)
            .with_body(file_content[start..=end].to_vec())
            .create_async()
            .await;
        range_mocks.push(mock);
    }

    let data_dir = tempfile::tempdir()?;
    let save_dir = tempfile::tempdir()?;
    let manager = DownloadManager::new(test_cfg(data_dir.path(), 3));
    let replace = AlwaysReplaceResolver {};
    let evaluated = manager
        .evaluate(EvaluateRequest::new(
            Url::parse(&endpoint).unwrap(),
            save_dir.path().to_path_buf(),
            &replace,
        ))
        .await?;

    let plan = evaluated.as_metadata();
    let part_count = plan.parts.len();
    assert_eq!(part_count, 3, "expected 3 parts, got {}", part_count);
    // Walk the parts in offset order: each must start where the previous ended.
    let mut spans: Vec<(u64, u64)> = plan.parts.values().map(|p| (p.offset, p.size)).collect();
    spans.sort_by_key(|&(offset, _)| offset);
    let covered = spans.iter().fold(0u64, |cursor, &(offset, len)| {
        assert_eq!(offset, cursor);
        cursor + len
    });
    assert_eq!(covered, size as u64);
    assert!(!plan.checksums.is_empty());

    let abort = AlwaysAbortResolver {};
    let saved_to = manager
        .download(DownloadRequest::new(evaluated, &abort))
        .await?;
    let on_disk = fs::read(&saved_to).await?;
    assert_eq!(on_disk.len(), file_content.len());
    assert_eq!(on_disk, file_content, "assembled file bytes mismatch");
    assert_eq!(sha256_base64(&on_disk[..]), expected_b64);

    head.assert_async().await;
    for mock in &range_mocks {
        mock.assert_async().await;
    }
    Ok(())
}
#[tokio::test]
async fn test_download_manager_random_user_agent() -> Result<(), Box<dyn std::error::Error>> {
let file_content = b"RandomUserAgentTestFile";
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/randomua")
.match_header("user-agent", Matcher::Any)
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "randomuaetag")
.with_header("last-modified", "Mon, 26 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let get_mock = server
.mock("GET", "/randomua")
.match_header("user-agent", Matcher::Any)
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(file_content)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.download(
DownloadOptionsBuilder::default()
.max_connections(1)
.randomize_user_agent(true)
.build()
.unwrap(),
)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(EvaluateRequest::new(
Url::parse(&format!("{}/randomua", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
&save_resolver,
))
.await?;
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: file_content.len() as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm
.download(DownloadRequest::new(instruction, &resolver))
.await?;
let result = tokio::fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
#[tokio::test]
async fn test_resumes_assembly_after_interrupt_with_no_server_checksum()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"AssemblyResumePayload-NoChecksumScenario-0123456789abcdefghijklmnop";
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let download_dir = tmp_data_dir.path().join("payload-bin");
fs::create_dir_all(&download_dir).await?;
let part1_ulid = "part1ulid_resume_test";
let part2_ulid = "part2ulid_resume_test";
let split: usize = 20;
let mut parts = HashMap::new();
parts.insert(
part1_ulid.to_string(),
PartDetails {
ulid: part1_ulid.to_string(),
offset: 0,
size: split as u64,
finished: true,
},
);
parts.insert(
part2_ulid.to_string(),
PartDetails {
ulid: part2_ulid.to_string(),
offset: split as u64,
size: (file_content.len() - split) as u64,
finished: true,
},
);
let instruction = DownloadBuilder::default()
.download_dir(download_dir.clone())
.save_dir(tmp_save_dir.path().to_path_buf())
.filename("payload.bin".to_string())
.url(Url::parse("http://example.invalid/payload.bin").unwrap())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts(parts)
.is_resumable(true)
.build()
.unwrap();
fs::write(instruction.part_path(part1_ulid), &file_content[..split]).await?;
fs::write(instruction.part_path(part2_ulid), &file_content[split..]).await?;
let metadata = instruction.as_metadata();
assert!(
!metadata.finished,
"precondition: metadata.finished should be false (mid-assembly state)"
);
assert!(
metadata.checksums.is_empty(),
"precondition: no server checksum (the bug condition)"
);
persist_metadata(&metadata, &instruction).await?;
let final_path = instruction.final_file_path();
fs::write(&final_path, vec![0u8; file_content.len()]).await?;
let cfg = test_cfg(tmp_data_dir.path(), 1);
let dlm = DownloadManager::new(cfg);
let resolver = AlwaysAbortResolver {};
let result_path = dlm
.download(DownloadRequest::new(instruction.clone(), &resolver))
.await?;
assert_eq!(result_path, final_path);
let on_disk = fs::read(&final_path).await?;
assert_eq!(
on_disk, file_content,
"final file must be re-assembled from parts, not left as the partial zero-padded carcass"
);
let bytes = fs::read(instruction.metadata_path()).await?;
let persisted = {
use prost::Message;
crate::download_metadata::DownloadMetadata::decode_length_delimited(&*bytes)
.expect("decode metadata")
};
assert!(
persisted.finished,
"post-condition: metadata.finished must be true after successful assembly"
);
Ok(())
}
#[tokio::test]
async fn test_recovers_from_crash_between_assembly_and_finished_persist()
-> Result<(), Box<dyn std::error::Error>> {
    // Simulates a crash after the final file was correctly assembled but before the
    // `finished` flag was persisted. The finished part, a correct final file, and
    // checksummed metadata are all on disk. The recovery fast path must leave the
    // final file untouched, remove the part, and persist `finished == true`.
    use base64::Engine;
    use sha2::{Digest, Sha256};
    let file_content = b"FastPathRecovery-AssemblyDone-FinishedFlagNotPersisted-XYZ";
    let mut hasher = Sha256::new();
    hasher.update(file_content);
    let sha256_b64 = base64::engine::general_purpose::STANDARD.encode(hasher.finalize());
    let tmp_data_dir = tempfile::tempdir()?;
    let tmp_save_dir = tempfile::tempdir()?;
    let download_dir = tmp_data_dir.path().join("payload-fastpath");
    fs::create_dir_all(&download_dir).await?;
    let part_ulid = "fastpath_part";
    let mut parts = HashMap::new();
    parts.insert(
        part_ulid.to_string(),
        PartDetails {
            ulid: part_ulid.to_string(),
            offset: 0,
            size: file_content.len() as u64,
            finished: true,
        },
    );
    let checksums = vec![crate::hash::HashDigest::SHA256(
        sha256_b64,
        crate::hash::HashEncoding::Base64,
    )];
    let instruction = DownloadBuilder::default()
        .download_dir(download_dir)
        .save_dir(tmp_save_dir.path().to_path_buf())
        .filename("payload-fastpath.bin".to_string())
        .url(Url::parse("http://example.invalid/payload-fastpath.bin").unwrap())
        .size(Some(file_content.len() as u64))
        .max_connections(1)
        .parts(parts)
        .is_resumable(true)
        .checksums(checksums)
        .build()
        .unwrap();
    let part_path = instruction.part_path(part_ulid);
    fs::write(&part_path, file_content).await?;
    fs::write(instruction.final_file_path(), file_content).await?;
    let metadata = instruction.as_metadata();
    assert!(!metadata.checksums.is_empty());
    persist_metadata(&metadata, &instruction).await?;
    // Explicit precondition, so the "part removed" post-condition below is meaningful.
    assert!(
        fs::try_exists(&part_path).await?,
        "precondition: part file must exist before recovery"
    );
    let pre_final_meta = std::fs::metadata(instruction.final_file_path())?;
    let cfg = test_cfg(tmp_data_dir.path(), 1);
    let dlm = DownloadManager::new(cfg);
    let resolver = AlwaysAbortResolver {};
    let final_path = dlm
        .download(DownloadRequest::new(instruction.clone(), &resolver))
        .await?;
    let on_disk = fs::read(&final_path).await?;
    assert_eq!(
        on_disk, file_content,
        "fast-path must preserve the already-correct final file"
    );
    let post_final_meta = std::fs::metadata(&final_path)?;
    assert_eq!(
        pre_final_meta.modified()?,
        post_final_meta.modified()?,
        "fast-path must not rewrite the final file"
    );
    // Propagate I/O errors rather than letting them count as "part is gone".
    assert!(
        !fs::try_exists(&part_path).await?,
        "fast-path must remove parts after marking finished"
    );
    let bytes = fs::read(instruction.metadata_path()).await?;
    let persisted = {
        use prost::Message;
        crate::download_metadata::DownloadMetadata::decode_length_delimited(&*bytes)
            .expect("decode metadata")
    };
    assert!(persisted.finished);
    Ok(())
}
#[tokio::test]
async fn test_recovery_fast_path_rejects_corrupt_existing_final()
-> Result<(), Box<dyn std::error::Error>> {
use base64::Engine;
use sha2::{Digest, Sha256};
let file_content = b"RejectCorruptExistingFinal-MustReassembleFromParts";
let mut hasher = Sha256::new();
hasher.update(file_content);
let sha256_b64 = base64::engine::general_purpose::STANDARD.encode(hasher.finalize());
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let download_dir = tmp_data_dir.path().join("payload-rejectfinal");
fs::create_dir_all(&download_dir).await?;
let part_ulid = "rejectfinal_part";
let mut parts = HashMap::new();
parts.insert(
part_ulid.to_string(),
PartDetails {
ulid: part_ulid.to_string(),
offset: 0,
size: file_content.len() as u64,
finished: true,
},
);
let checksums = vec![crate::hash::HashDigest::SHA256(
sha256_b64,
crate::hash::HashEncoding::Base64,
)];
let instruction = DownloadBuilder::default()
.download_dir(download_dir.clone())
.save_dir(tmp_save_dir.path().to_path_buf())
.filename("payload-rejectfinal.bin".to_string())
.url(Url::parse("http://example.invalid/payload-rejectfinal.bin").unwrap())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts(parts)
.is_resumable(true)
.checksums(checksums)
.build()
.unwrap();
fs::write(instruction.part_path(part_ulid), file_content).await?;
fs::write(instruction.final_file_path(), vec![0u8; file_content.len()]).await?;
let metadata = instruction.as_metadata();
persist_metadata(&metadata, &instruction).await?;
let cfg = test_cfg(tmp_data_dir.path(), 1);
let dlm = DownloadManager::new(cfg);
let resolver = AlwaysAbortResolver {};
let final_path = dlm
.download(DownloadRequest::new(instruction.clone(), &resolver))
.await?;
let on_disk = fs::read(&final_path).await?;
assert_eq!(
on_disk, file_content,
"must re-assemble: existing final was zero-padded carcass"
);
Ok(())
}
#[tokio::test]
async fn test_resumes_assembly_when_final_file_absent() -> Result<(), Box<dyn std::error::Error>>
{
    // The single part is finished and its metadata is persisted, but no final file
    // exists yet (e.g. an interruption before assembly began). The download must
    // assemble the final file from the part.
    let payload = b"AnotherResumePayload-FinalFileAbsent";
    let data_dir = tempfile::tempdir()?;
    let save_dir = tempfile::tempdir()?;
    let work_dir = data_dir.path().join("payload-bin-2");
    fs::create_dir_all(&work_dir).await?;

    let ulid = "only_part_resume2";
    let mut parts = HashMap::with_capacity(1);
    parts.insert(
        ulid.to_string(),
        PartDetails {
            ulid: ulid.to_string(),
            offset: 0,
            size: payload.len() as u64,
            finished: true,
        },
    );
    let download = DownloadBuilder::default()
        .download_dir(work_dir)
        .save_dir(save_dir.path().to_path_buf())
        .filename("payload2.bin".to_string())
        .url(Url::parse("http://example.invalid/payload2.bin").unwrap())
        .size(Some(payload.len() as u64))
        .max_connections(1)
        .parts(parts)
        .is_resumable(true)
        .build()
        .unwrap();
    fs::write(download.part_path(ulid), payload).await?;
    persist_metadata(&download.as_metadata(), &download).await?;

    // Precondition: nothing has been assembled yet.
    let final_exists = fs::try_exists(download.final_file_path()).await;
    assert!(!matches!(final_exists, Ok(true)));

    let manager = DownloadManager::new(test_cfg(data_dir.path(), 1));
    let abort = AlwaysAbortResolver {};
    let saved_to = manager
        .download(DownloadRequest::new(download, &abort))
        .await?;
    let assembled = fs::read(&saved_to).await?;
    assert_eq!(assembled, payload);
    Ok(())
}
}