mod checksum;
mod downloader;
mod io;
mod recover_metadata;
mod save_conflict;
mod server_conflict;
use std::{path::PathBuf, sync::Arc};
use fs2::FileExt;
use reqwest::{
Client, Proxy, Url,
header::{HeaderMap, USER_AGENT},
};
use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
use tracing::Span;
use tracing_indicatif::span_ext::IndicatifSpanExt;
use crate::config::Config;
use crate::download_manager::recover_metadata::recover_metadata;
use crate::download_manager::{downloader::Downloader, io::persist_metadata};
use crate::download_manager::{io::assemble_final_file, server_conflict::resolve_server_conflicts};
use crate::download_manager::{io::remove_all_parts, save_conflict::resolve_save_conflicts};
use crate::error::MetadataError;
use crate::response_info::ResponseInfo;
use crate::retry_policies::{FixedThenExponentialRetry, wait_for_retry};
use crate::{
conflict::{SaveConflictResolver, ServerConflictResolver},
download_manager::io::sum_parts_on_disk,
};
use crate::{credentials::Credentials, user_agents::random_user_agent};
use crate::{download::Download, download_metadata::PartDetails, error::OdlError};
/// Entry point for evaluating URLs and running downloads under one shared [`Config`].
#[derive(Debug)]
pub struct DownloadManager {
// Settings read when building HTTP clients, retry policies and `Download` instructions.
config: Config,
// Permit pool created with `config.max_concurrent_downloads` permits; permits are
// handed out by `acquire_download_permit` and resized by `set_config`.
semaphore: Arc<Semaphore>,
}
impl DownloadManager {
/// Creates a manager that owns `config` together with a semaphore holding
/// `config.max_concurrent_downloads` permits.
pub fn new(config: Config) -> DownloadManager {
    // Read the permit count before `config` is moved into the struct.
    let semaphore = Arc::new(Semaphore::new(config.max_concurrent_downloads));
    DownloadManager { config, semaphore }
}
/// Returns a shared reference to the configuration currently held by the manager.
pub fn config(&self) -> &Config {
&self.config
}
/// Replaces the active configuration and resizes the concurrency semaphore to match
/// the new `max_concurrent_downloads`.
///
/// * When the limit grows, the difference is added to the semaphore immediately.
/// * When the limit shrinks, the surplus permits are acquired and then forgotten, so
///   this call waits until enough permits are available to be retired.
///
/// The new configuration is stored before any permits are acquired.
///
/// # Errors
///
/// Propagates the [`AcquireError`] returned by the semaphore's acquire call.
pub async fn set_config(&mut self, value: Config) -> Result<(), AcquireError> {
    let old_max = self.config.max_concurrent_downloads;
    self.config = value;
    let new_max = self.config.max_concurrent_downloads;

    if new_max > old_max {
        self.semaphore.add_permits(new_max - old_max);
        return Ok(());
    }

    // `acquire_many_owned` takes a `u32`, while the limits are `usize`. Retire the
    // surplus in `u32`-sized chunks instead of truncating the count with `as`.
    let mut remaining = old_max - new_max;
    while remaining > 0 {
        let step = u32::try_from(remaining).unwrap_or(u32::MAX);
        Arc::clone(&self.semaphore)
            .acquire_many_owned(step)
            .await?
            .forget();
        // `step` never exceeds `remaining`, so the cast is lossless and the
        // subtraction cannot underflow.
        remaining -= step as usize;
    }
    Ok(())
}
/// Probes `url` with a HEAD request and converts the response into a [`Download`]
/// instruction rooted at `save_dir`.
///
/// The request advertises digest preferences through `Want-Repr-Digest` and
/// `Want-Content-Digest`, applies basic auth when `credentials` is given, and sends a
/// random `User-Agent` only when no fixed user agent is configured and
/// `randomize_user_agent` is enabled. Failed requests (transport errors or error
/// status codes) are retried according to the configured
/// [`FixedThenExponentialRetry`] policy.
///
/// After save conflicts are resolved through `conflict_resolver`, the current span's
/// progress bar is labelled with the file name and its length is set to the reported
/// size (or `0` when the size is unknown).
///
/// # Errors
///
/// Returns an [`OdlError`] when the HTTP client cannot be built, when the HEAD request
/// still fails once `wait_for_retry` declines another attempt, or when
/// `resolve_save_conflicts` fails.
pub async fn evaluate<CR>(
    &self,
    url: Url,
    save_dir: PathBuf,
    credentials: Option<Credentials>,
    conflict_resolver: &CR,
) -> Result<Download, OdlError>
where
    CR: SaveConflictResolver,
{
    // Digest algorithm preferences sent with both `Want-*-Digest` headers.
    const WANTED_DIGESTS: &str = "sha-512=9, sha-384=8, sha-256=7, sha-1=1, md5=1";

    let current_span = Span::current();
    current_span.pb_set_message("Evaluating");
    let client = self.get_client(None)?;
    let retry_policy = FixedThenExponentialRetry {
        max_n_retries: self.config.max_retries,
        wait_time: self.config.wait_between_retries,
        n_fixed_retries: self.config.n_fixed_retries,
    };

    let mut attempts: u32 = 0;
    let resp = loop {
        // The request builder is consumed by `send`, so it is rebuilt on every attempt.
        let mut req = client
            .head(url.clone())
            .header("Want-Repr-Digest", WANTED_DIGESTS)
            .header("Want-Content-Digest", WANTED_DIGESTS);
        if let Some(creds) = &credentials {
            req = req.basic_auth(creds.username(), creds.password());
        }
        if self.config.user_agent.is_none() && self.config.randomize_user_agent {
            req = req.header(USER_AGENT, random_user_agent());
        }

        match req.send().await.and_then(|r| r.error_for_status()) {
            Ok(r) => break r,
            Err(e) => {
                attempts = attempts.saturating_add(1);
                // A `false` result means no further retry will be made.
                if !wait_for_retry(&retry_policy, attempts, &current_span).await {
                    return Err(OdlError::from(e));
                }
                // Presumably `wait_for_retry` changes the span message while waiting;
                // restore the label before the next attempt.
                current_span.pb_set_message("Evaluating");
            }
        }
    };

    let info = ResponseInfo::from(resp);
    let instruction = Download::from_response_info(
        &self.config.download_dir,
        save_dir,
        info,
        self.config.max_connections,
        self.config.use_server_time,
        credentials,
        Option::<Proxy>::from(&self.config),
        Some(HeaderMap::from(&self.config)),
    );
    let instruction = resolve_save_conflicts(instruction, conflict_resolver).await?;

    current_span.pb_set_message(instruction.filename());
    current_span.pb_set_length(instruction.size().unwrap_or(0));
    Ok(instruction)
}
/// Runs the download described by `instruction` while holding an exclusive lock on
/// its lockfile, and returns the path produced by `process_download`.
///
/// The download directory is created first, then the lockfile is opened (created and
/// truncated if needed) and locked without blocking. The lock is released explicitly
/// once `process_download` has finished, whatever its outcome.
///
/// # Errors
///
/// * [`OdlError::StdIoError`] when the lockfile cannot be opened.
/// * [`MetadataError::LockfileInUse`] when the exclusive lock cannot be taken.
/// * Any error returned by creating the directory or by `process_download`.
pub async fn download<CR>(
    &self,
    instruction: Download,
    conflict_resolver: &CR,
) -> Result<PathBuf, OdlError>
where
    CR: ServerConflictResolver,
{
    tokio::fs::create_dir_all(instruction.download_dir()).await?;

    let lock_handle = tokio::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(instruction.lockfile_path())
        .await
        .map_err(|e| OdlError::StdIoError {
            e,
            extra_info: Some(format!(
                "Failed to open lockfile for exclusive locking at {}",
                instruction.lockfile_path().display(),
            )),
        })?;

    // The locking API used here works on the std file handle.
    let lock_handle = lock_handle.into_std().await;
    if lock_handle.try_lock_exclusive().is_err() {
        return Err(OdlError::MetadataError(MetadataError::LockfileInUse));
    }

    let outcome = self.process_download(instruction, conflict_resolver).await;
    // Best-effort unlock: a failure here does not change the download's outcome.
    let _ = FileExt::unlock(&lock_handle);
    outcome
}
/// Waits for one permit from the manager's semaphore and returns it as an owned permit.
///
/// The `Arc` is cloned so the returned permit does not borrow from `self`.
/// Any `AcquireError` from the semaphore is returned unchanged.
pub async fn acquire_download_permit(&self) -> Result<OwnedSemaphorePermit, AcquireError> {
Arc::clone(&self.semaphore).acquire_owned().await
}
fn get_client(&self, instructions: Option<&Download>) -> Result<Client, OdlError> {
let mut client = reqwest::Client::builder();
if let Some(download) = instructions {
if let Some(proxy) = download.proxy() {
client = client.proxy(proxy.clone());
}
if let Some(headers) = download.headers() {
client = client.default_headers(headers.clone());
}
} else {
if self.config.headers.as_ref().is_some_and(|x| !x.is_empty()) {
client = client.default_headers(HeaderMap::from(&self.config));
}
if let Some(proxy) = Option::<Proxy>::from(&self.config) {
client = client.proxy(proxy);
}
}
if self.config.accept_invalid_certs {
client = client.danger_accept_invalid_certs(self.config.accept_invalid_certs)
}
if let Some(user_agent) = &self.config.user_agent {
client = client.user_agent(user_agent.clone());
}
if let Some(timeout) = &self.config.connect_timeout {
client = client.connect_timeout(*timeout);
}
Ok(client.build()?)
}
async fn process_download<CR>(
&self,
instruction: Download,
conflict_resolver: &CR,
) -> Result<PathBuf, OdlError>
where
CR: ServerConflictResolver,
{
tokio::fs::create_dir_all(instruction.save_dir()).await?;
recover_metadata(&instruction).await?;
let mut metadata = resolve_server_conflicts(&instruction, conflict_resolver).await?;
if let Some(sum_of_parts_sizes) = sum_parts_on_disk(&instruction, &metadata).await {
let size: Option<u64> = metadata.size.or_else(|| instruction.size());
let current_span = Span::current();
current_span.pb_reset();
if let Some(size) = size {
current_span.pb_set_length(size);
}
current_span.pb_set_position(sum_of_parts_sizes);
current_span.pb_reset_eta();
}
if !metadata.finished {
let to_download = metadata
.parts
.iter()
.filter_map(|(_, p)| if !p.finished { Some(p.clone()) } else { None })
.collect::<Vec<PartDetails>>();
if !to_download.is_empty() {
let randomize_user_agent = if self.config.user_agent.is_some() {
false
} else {
self.config.randomize_user_agent
};
let client = self.get_client(Some(&instruction))?;
let retry_policy = crate::retry_policies::FixedThenExponentialRetry {
max_n_retries: self.config.max_retries,
wait_time: self.config.wait_between_retries,
n_fixed_retries: self.config.n_fixed_retries,
};
let downloader = Downloader::new(
Arc::new(instruction.clone()),
metadata,
client,
randomize_user_agent,
Span::current(),
self.config.speed_limit,
retry_policy,
);
let mut mdata = downloader.run().await?;
mdata.finished = true;
persist_metadata(&mdata, &instruction).await?;
metadata = mdata;
}
let final_path = assemble_final_file(&metadata, &instruction).await?;
remove_all_parts(instruction.download_dir()).await;
Ok(final_path)
} else {
let final_path = instruction.final_file_path();
if tokio::fs::try_exists(&final_path).await.unwrap_or(false) {
Ok(final_path)
} else {
Err(OdlError::StdIoError {
e: std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Expected final file not found at {}", final_path.display()),
),
extra_info: None,
})
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::conflict::FileChangedResolution;
use crate::conflict::NotResumableResolution;
use crate::conflict::ServerConflict;
use crate::download::DownloadBuilder;
use crate::download_metadata::PartDetails;
use crate::error::ConflictError;
use async_trait::async_trait;
use mockito::Matcher;
use mockito::Server;
use std::collections::HashMap;
use tempfile::tempdir;
use tokio::fs;
use tokio::io::AsyncWriteExt;
/// Test resolver that answers `Abort` to every server-side conflict.
struct AlwaysAbortResolver;
#[async_trait]
impl ServerConflictResolver for AlwaysAbortResolver {
// A changed remote file aborts the download.
async fn resolve_file_changed(&self, _: &Download) -> FileChangedResolution {
FileChangedResolution::Abort
}
// A non-resumable download aborts as well.
async fn resolve_not_resumable(&self, _: &Download) -> NotResumableResolution {
NotResumableResolution::Abort
}
}
/// Test resolver for save conflicts: replaces an existing final file and resumes an
/// existing download of the same target.
struct AlwaysReplaceResolver;
#[async_trait]
impl SaveConflictResolver for AlwaysReplaceResolver {
// An existing final file is replaced and the evaluation continues.
async fn final_file_exists(
&self,
_: &Download,
) -> crate::conflict::FinalFileExistsResolution {
crate::conflict::FinalFileExistsResolution::ReplaceAndContinue
}
// An already-known download of the same target is resumed.
async fn same_download_exists(
&self,
_: &Download,
) -> crate::conflict::SameDownloadExistsResolution {
crate::conflict::SameDownloadExistsResolution::Resume
}
}
/// Downloads a small file as two explicitly defined ranged parts and checks that the
/// assembled file equals the original bytes.
///
/// The mock server answers one HEAD and two ranged GETs (`bytes=0-9` and
/// `bytes=10-<len-1>`); all three mocks are asserted at the end.
#[tokio::test]
async fn test_download_manager_multipart_download() -> Result<(), Box<dyn std::error::Error>> {
let file_content = b"HelloWorldThisIsATestFile";
let part1 = &file_content[..10]; let part2 = &file_content[10..];
let mut server = Server::new_async().await;
let url = server.url();
// HEAD response consumed by `evaluate`.
let head_mock = server
.mock("HEAD", "/testfile")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "testetag")
.with_header("last-modified", "Wed, 21 Oct 2015 07:28:00 GMT")
.create_async()
.await;
// First ranged GET: the first ten bytes.
let get_mock1 = server
.mock("GET", "/testfile")
.match_header("range", Matcher::Exact("bytes=0-9".into()))
.with_status(206)
.with_body(part1)
.create_async()
.await;
// Second ranged GET: everything after the first ten bytes.
let get_mock2 = server
.mock("GET", "/testfile")
.match_header(
"range",
Matcher::Exact(format!("bytes=10-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(part2)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(2)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/testfile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
// Rebuild the instruction with hand-picked parts so the requested ranges match the
// GET mocks above; directories, filename and URL are taken from the evaluated one.
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(2)
.parts({
let mut parts = HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: 10,
finished: false,
},
);
parts.insert(
"part2".to_string(),
PartDetails {
ulid: "part2".to_string(),
offset: 10,
size: (file_content.len() - 10) as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm.download(instruction, &resolver).await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
// Check the expectations recorded on each mock.
head_mock.assert_async().await;
get_mock1.assert_async().await;
get_mock2.assert_async().await;
Ok(())
}
/// Checks that `evaluate` fails with a `FinalFileExists` save conflict when the target
/// file already exists and the resolver answers `Abort`.
#[tokio::test]
async fn test_save_conflict_final_file_exists_abort() -> Result<(), Box<dyn std::error::Error>>
{
let mut server = Server::new_async().await;
let base = server.url();
let head_mock = server
.mock("HEAD", "/file_abort")
.with_status(200)
.with_header("content-length", "1")
.with_header("accept-ranges", "bytes")
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
// Pre-create a file in the save directory under the name used by the URL path.
let filename = "file_abort";
let final_path = tmp_save_dir.path().join(filename);
tokio::fs::write(&final_path, b"x").await?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(1)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
// Resolver that aborts when the final file already exists.
struct AbortFinalResolver;
#[async_trait]
impl SaveConflictResolver for AbortFinalResolver {
async fn final_file_exists(
&self,
_: &Download,
) -> crate::conflict::FinalFileExistsResolution {
crate::conflict::FinalFileExistsResolution::Abort
}
async fn same_download_exists(
&self,
_: &Download,
) -> crate::conflict::SameDownloadExistsResolution {
crate::conflict::SameDownloadExistsResolution::Resume
}
}
let resolver = AbortFinalResolver {};
let result = dlm
.evaluate(
Url::parse(&format!("{}/file_abort", base)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&resolver,
)
.await;
assert!(matches!(
result,
Err(OdlError::Conflict(ConflictError::Save {
conflict: crate::conflict::SaveConflict::FinalFileExists
}))
));
head_mock.assert_async().await;
Ok(())
}
/// Checks that `evaluate` renames the target to `file_add_2` when the final file
/// already exists and the resolver answers `AddNumberToNameAndContinue`.
#[tokio::test]
async fn test_save_conflict_final_file_exists_add_number()
-> Result<(), Box<dyn std::error::Error>> {
let mut server = Server::new_async().await;
let base = server.url();
let head_mock = server
.mock("HEAD", "/file_add")
.with_status(200)
.with_header("content-length", "1")
.with_header("accept-ranges", "bytes")
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
// Pre-create a file in the save directory under the name used by the URL path.
let filename = "file_add";
let final_path = tmp_save_dir.path().join(filename);
tokio::fs::write(&final_path, b"x").await?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(1)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
// Resolver that asks for a numbered file name when the final file exists.
struct AddNumberResolver;
#[async_trait]
impl SaveConflictResolver for AddNumberResolver {
async fn final_file_exists(
&self,
_: &Download,
) -> crate::conflict::FinalFileExistsResolution {
crate::conflict::FinalFileExistsResolution::AddNumberToNameAndContinue
}
async fn same_download_exists(
&self,
_: &Download,
) -> crate::conflict::SameDownloadExistsResolution {
crate::conflict::SameDownloadExistsResolution::Resume
}
}
let resolver = AddNumberResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/file_add", base)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&resolver,
)
.await?;
assert_eq!(instruction.filename(), "file_add_2");
head_mock.assert_async().await;
Ok(())
}
/// Downloads a file as one part covering the whole content and checks the result on
/// disk; the HEAD mock and the single ranged GET mock are asserted at the end.
#[tokio::test]
async fn test_download_manager_single_part_download() -> Result<(), Box<dyn std::error::Error>>
{
let file_content = b"SinglePartFileContent";
let part = &file_content[..];
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/singlefile")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "singleetag")
.with_header("last-modified", "Thu, 22 Oct 2015 07:28:00 GMT")
.create_async()
.await;
// The only GET expected: a range spanning the full file.
let get_mock = server
.mock("GET", "/singlefile")
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(part)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(1)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/singlefile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
// Rebuild the instruction with one explicit part; directories, filename and URL are
// taken from the evaluated instruction.
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts({
let mut parts = HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: file_content.len() as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm.download(instruction, &resolver).await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
/// Checks that downloading a two-part instruction marked as not resumable fails with a
/// `NotResumable` server conflict when the resolver answers `Abort`.
///
/// The HEAD mock sends no `accept-ranges` header and no GET mock is registered.
#[tokio::test]
async fn test_download_manager_multipart_not_resumable_download()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"NonResumableMultipartFile";
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/nonresumablefile")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("etag", "nonresumableetag")
.with_header("last-modified", "Fri, 23 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(2)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/nonresumablefile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
// Hand-built instruction: two parts, explicitly flagged as not resumable.
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(2)
.parts({
let mut parts = HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: 10,
finished: false,
},
);
parts.insert(
"part2".to_string(),
PartDetails {
ulid: "part2".to_string(),
offset: 10,
size: (file_content.len() - 10) as u64,
finished: false,
},
);
parts
})
.is_resumable(false)
.build()
.unwrap();
// Local resolver that aborts on every server conflict.
struct AssertTestResolver;
#[async_trait]
impl ServerConflictResolver for AssertTestResolver {
async fn resolve_file_changed(&self, _: &Download) -> FileChangedResolution {
FileChangedResolution::Abort
}
async fn resolve_not_resumable(&self, _: &Download) -> NotResumableResolution {
NotResumableResolution::Abort
}
}
let resolver = AssertTestResolver {};
let result = dlm.download(instruction, &resolver).await;
assert!(matches!(
result,
Err(OdlError::Conflict(ConflictError::Server {
conflict: ServerConflict::NotResumable
}))
));
head_mock.assert_async().await;
Ok(())
}
/// Checks that a two-part, not-resumable instruction still completes when the resolver
/// answers `Restart`: the only GET mock expects one range spanning the whole file, and
/// the final file must equal the original bytes.
#[tokio::test]
async fn test_download_manager_multipart_not_resumable_restart_download()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"NonResumableMultipartFile";
let part = &file_content[..];
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/nonresumablefile_restart")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("etag", "nonresumableetag")
.with_header("last-modified", "Fri, 23 Oct 2015 07:28:00 GMT")
.create_async()
.await;
// A single GET for the full byte range is expected, not one per declared part.
let get_mock = server
.mock("GET", "/nonresumablefile_restart")
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(part)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(2)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/nonresumablefile_restart", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
// Hand-built instruction: two parts, explicitly flagged as not resumable.
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(2)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: 10,
finished: false,
},
);
parts.insert(
"part2".to_string(),
PartDetails {
ulid: "part2".to_string(),
offset: 10,
size: (file_content.len() - 10) as u64,
finished: false,
},
);
parts
})
.is_resumable(false)
.build()
.unwrap();
// Local resolver that restarts on every server conflict.
struct AssertTestResolver;
#[async_trait]
impl ServerConflictResolver for AssertTestResolver {
async fn resolve_file_changed(&self, _: &Download) -> FileChangedResolution {
FileChangedResolution::Restart
}
async fn resolve_not_resumable(&self, _: &Download) -> NotResumableResolution {
NotResumableResolution::Restart
}
}
let resolver = AssertTestResolver {};
let final_path = dlm.download(instruction, &resolver).await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
/// Checks that a zero-length file (one part of size 0) downloads successfully and
/// yields an empty final file. Only a HEAD mock is registered; no GET mock exists.
#[tokio::test]
async fn test_download_manager_zero_byte_single_part_download()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"";
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/zerofile")
.with_status(200)
.with_header("content-length", "0")
.with_header("accept-ranges", "bytes")
.with_header("etag", "zeroetag")
.with_header("last-modified", "Sat, 24 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(1)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/zerofile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
// Hand-built instruction with a single empty part.
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(0))
.max_connections(1)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: 0,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm.download(instruction, &resolver).await?;
let result = fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
Ok(())
}
/// Drives a `Downloader` directly against a part file that already holds the first
/// seven bytes, and checks that the GET asks only for the remaining range and that the
/// part file ends up with the full content.
#[tokio::test]
async fn test_download_part_resumes_with_correct_range()
-> Result<(), Box<dyn std::error::Error>> {
let file_content = b"PartialDownloadTestFile";
let part_offset = 0;
let part_size = file_content.len() as u64;
let already_downloaded = 7;
let mut server = Server::new_async().await;
let url = server.url();
// Range expected on the wire: starts after the bytes already on disk.
let expected_range = format!(
"bytes={}-{}",
part_offset + already_downloaded,
part_offset + part_size - 1
);
let get_mock = server
.mock("GET", "/partialfile")
.match_header("range", Matcher::Exact(expected_range.clone()))
.with_status(206)
.with_body(&file_content[already_downloaded as usize..])
.create_async()
.await;
let tmp_dir = tempdir()?;
let download_dir = tmp_dir.path().join("partial");
fs::create_dir_all(&download_dir).await?;
// Seed the part file with the prefix that counts as already downloaded.
let part_path = download_dir.join("part1.part");
{
let mut f = fs::File::create(&part_path).await?;
f.write_all(&file_content[..already_downloaded as usize])
.await?;
}
let part_details = PartDetails {
ulid: "part1".to_string(),
offset: part_offset,
size: part_size,
finished: false,
};
let mut parts_map = HashMap::new();
parts_map.insert(part_details.ulid.clone(), part_details.clone());
let instruction = DownloadBuilder::default()
.download_dir(download_dir.clone())
.save_dir(tmp_dir.path().to_path_buf())
.filename("partialfile.bin".to_string())
.url(Url::parse(&format!("{}/partialfile", url)).unwrap())
.is_resumable(true)
.max_connections(1)
.size(Some(part_size))
.parts(parts_map)
.build()
.unwrap();
let metadata = instruction.as_metadata();
let client = reqwest::Client::builder().build()?;
let retry_policy = crate::retry_policies::FixedThenExponentialRetry {
max_n_retries: 6,
wait_time: std::time::Duration::from_millis(100),
n_fixed_retries: 3,
};
// Arguments after `client`: no user-agent randomization, current span, no speed
// limit, then the retry policy (same order as in `process_download`).
let downloader = Downloader::new(
Arc::new(instruction.clone()),
metadata,
client,
false,
Span::current(),
None,
retry_policy,
);
let updated_metadata = downloader.run().await?;
// The part must be reported as finished in the returned metadata.
assert!(
updated_metadata
.parts
.get("part1")
.map(|p| p.finished)
.unwrap_or(false)
);
let result = tokio::fs::read(&part_path).await?;
assert_eq!(result, file_content);
get_mock.assert_async().await;
Ok(())
}
/// Checks that a configured user agent is sent on both the HEAD and the GET request:
/// both mocks match only when `user-agent` equals the custom value exactly.
#[tokio::test]
async fn test_download_manager_custom_user_agent() -> Result<(), Box<dyn std::error::Error>> {
let file_content = b"UserAgentTestFile";
let mut server = Server::new_async().await;
let url = server.url();
let custom_ua = "MyCustomUserAgent/1.0";
let head_mock = server
.mock("HEAD", "/useragentfile")
.match_header("user-agent", Matcher::Exact(custom_ua.into()))
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "uaetag")
.with_header("last-modified", "Sun, 25 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let get_mock = server
.mock("GET", "/useragentfile")
.match_header("user-agent", Matcher::Exact(custom_ua.into()))
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(file_content)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
// Fixed user agent with randomization switched off.
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(1)
.user_agent(Some(custom_ua.to_string()))
.randomize_user_agent(false)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/useragentfile", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
// Hand-built instruction with one part covering the whole file.
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: file_content.len() as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm.download(instruction, &resolver).await?;
let result = tokio::fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
/// End-to-end run using the instruction exactly as returned by `evaluate`: the HEAD
/// response carries a `Repr-Digest` with the real SHA-256 of the payload, and the test
/// checks that a checksum was extracted, the size matches, the download succeeds and
/// the file on disk hashes to the same digest.
#[tokio::test]
async fn test_e2e_evaluate_download_assemble_with_checksum()
-> Result<(), Box<dyn std::error::Error>> {
use base64::Engine;
use sha2::{Digest, Sha256};
let file_content = b"E2E full pipeline payload: evaluate -> download -> assemble -> verify";
// Compute the digest the mock server will advertise.
let mut hasher = Sha256::new();
hasher.update(file_content);
let sha256_b64 = base64::engine::general_purpose::STANDARD.encode(hasher.finalize());
let repr_digest_value = format!("sha-256=:{}:", sha256_b64);
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/payload.bin")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "e2eetag")
.with_header("last-modified", "Tue, 27 Oct 2015 07:28:00 GMT")
.with_header("Repr-Digest", &repr_digest_value)
.create_async()
.await;
let get_mock = server
.mock("GET", "/payload.bin")
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(file_content)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(1)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/payload.bin", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
assert!(
!instruction.as_metadata().checksums.is_empty(),
"evaluate did not extract checksum from Repr-Digest"
);
assert_eq!(instruction.size(), Some(file_content.len() as u64));
let resolver = AlwaysAbortResolver {};
let final_path = dlm.download(instruction, &resolver).await?;
let on_disk = fs::read(&final_path).await?;
assert_eq!(on_disk, file_content, "final file content mismatch");
// Re-hash the bytes read back from disk and compare with the advertised digest.
let mut hasher = Sha256::new();
hasher.update(&on_disk);
let actual_b64 = base64::engine::general_purpose::STANDARD.encode(hasher.finalize());
assert_eq!(actual_b64, sha256_b64);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
/// Checks that `download` fails with `ConflictError::ChecksumMismatch` when the HEAD
/// response advertises a `Repr-Digest` that does not match the served payload.
#[tokio::test]
async fn test_e2e_download_fails_on_checksum_mismatch() -> Result<(), Box<dyn std::error::Error>>
{
let file_content = b"payload-served-by-server";
// Fixed digest value that is not derived from `file_content`.
let bogus_repr_digest = "sha-256=:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=:";
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/bad.bin")
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("Repr-Digest", bogus_repr_digest)
.create_async()
.await;
let get_mock = server
.mock("GET", "/bad.bin")
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(file_content)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(1)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/bad.bin", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
let resolver = AlwaysAbortResolver {};
let result = dlm.download(instruction, &resolver).await;
assert!(
matches!(
result,
Err(OdlError::Conflict(ConflictError::ChecksumMismatch { .. }))
),
"expected ChecksumMismatch, got {:?}",
result
);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
/// End-to-end multipart run over 900 KiB of seeded pseudo-random data with
/// `max_connections(3)`.
///
/// Checks that `evaluate` yields three contiguous parts covering the whole file plus a
/// checksum, that the download succeeds through three ranged GET mocks, and that the
/// assembled file matches the source bytes and the advertised SHA-256 digest.
#[tokio::test]
async fn test_e2e_multipart_evaluate_download_assemble_with_checksum()
-> Result<(), Box<dyn std::error::Error>> {
use base64::Engine;
use sha2::{Digest, Sha256};
use rand::{RngCore, SeedableRng, rngs::StdRng};
let size: usize = 900 * 1024;
// A fixed seed keeps the payload, and therefore its digest, reproducible.
let mut rng = StdRng::seed_from_u64(0x00C0_FFEE_F00D);
let mut file_content = vec![0u8; size];
rng.fill_bytes(&mut file_content);
let mut hasher = Sha256::new();
hasher.update(&file_content);
let sha256_b64 = base64::engine::general_purpose::STANDARD.encode(hasher.finalize());
let repr_digest_value = format!("sha-256=:{}:", sha256_b64);
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/big.bin")
.with_status(200)
.with_header("content-length", &size.to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "bigetag")
.with_header("Repr-Digest", &repr_digest_value)
.create_async()
.await;
// Register one GET mock per third of the file; the last range runs to the final byte.
let part_size = size / 3;
let mut get_mocks = Vec::new();
for i in 0..3 {
let start = i * part_size;
let end = if i == 2 {
size - 1
} else {
start + part_size - 1
};
let body = file_content[start..=end].to_vec();
let m = server
.mock("GET", "/big.bin")
.match_header("range", Matcher::Exact(format!("bytes={}-{}", start, end)))
.with_status(206)
.with_body(body)
.create_async()
.await;
get_mocks.push(m);
}
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(3)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/big.bin", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
let metadata = instruction.as_metadata();
assert_eq!(
metadata.parts.len(),
3,
"expected 3 parts, got {}",
metadata.parts.len()
);
// Sorted by offset, the parts must tile the file with no gaps or overlaps.
let mut offsets: Vec<(u64, u64)> = metadata
.parts
.values()
.map(|p| (p.offset, p.size))
.collect();
offsets.sort_by_key(|(o, _)| *o);
let mut covered: u64 = 0;
for (off, sz) in &offsets {
assert_eq!(*off, covered);
covered += sz;
}
assert_eq!(covered, size as u64);
assert!(!metadata.checksums.is_empty());
let resolver = AlwaysAbortResolver {};
let final_path = dlm.download(instruction, &resolver).await?;
let on_disk = fs::read(&final_path).await?;
assert_eq!(on_disk.len(), file_content.len());
assert_eq!(on_disk, file_content, "assembled file bytes mismatch");
let mut hasher = Sha256::new();
hasher.update(&on_disk);
let actual_b64 = base64::engine::general_purpose::STANDARD.encode(hasher.finalize());
assert_eq!(actual_b64, sha256_b64);
head_mock.assert_async().await;
for m in &get_mocks {
m.assert_async().await;
}
Ok(())
}
/// Checks that a download succeeds with `randomize_user_agent(true)`: both mocks
/// require a `user-agent` header to be present but accept any value.
#[tokio::test]
async fn test_download_manager_random_user_agent() -> Result<(), Box<dyn std::error::Error>> {
let file_content = b"RandomUserAgentTestFile";
let mut server = Server::new_async().await;
let url = server.url();
let head_mock = server
.mock("HEAD", "/randomua")
.match_header("user-agent", Matcher::Any)
.with_status(200)
.with_header("content-length", &file_content.len().to_string())
.with_header("accept-ranges", "bytes")
.with_header("etag", "randomuaetag")
.with_header("last-modified", "Mon, 26 Oct 2015 07:28:00 GMT")
.create_async()
.await;
let get_mock = server
.mock("GET", "/randomua")
.match_header("user-agent", Matcher::Any)
.match_header(
"range",
Matcher::Exact(format!("bytes=0-{}", file_content.len() - 1)),
)
.with_status(206)
.with_body(file_content)
.create_async()
.await;
let tmp_data_dir = tempfile::tempdir()?;
let tmp_save_dir = tempfile::tempdir()?;
// No fixed user agent is configured; only randomization is enabled.
let cfg = crate::config::ConfigBuilder::default()
.download_dir(tmp_data_dir.path().to_path_buf())
.max_connections(1)
.randomize_user_agent(true)
.build()
.unwrap();
let dlm = DownloadManager::new(cfg);
let save_resolver = AlwaysReplaceResolver {};
let instruction = dlm
.evaluate(
Url::parse(&format!("{}/randomua", url)).unwrap(),
tmp_save_dir.path().to_path_buf(),
None,
&save_resolver,
)
.await?;
// Hand-built instruction with one part covering the whole file.
let instruction = DownloadBuilder::default()
.download_dir(instruction.download_dir().clone())
.save_dir(instruction.save_dir().clone())
.filename(instruction.filename().to_string())
.url(instruction.url().clone())
.size(Some(file_content.len() as u64))
.max_connections(1)
.parts({
let mut parts = std::collections::HashMap::new();
parts.insert(
"part1".to_string(),
PartDetails {
ulid: "part1".to_string(),
offset: 0,
size: file_content.len() as u64,
finished: false,
},
);
parts
})
.is_resumable(true)
.build()
.unwrap();
let resolver = AlwaysAbortResolver {};
let final_path = dlm.download(instruction, &resolver).await?;
let result = tokio::fs::read(&final_path).await?;
assert_eq!(result, file_content);
head_mock.assert_async().await;
get_mock.assert_async().await;
Ok(())
}
}