use crate::{
credentials::Credentials,
download_metadata::{DownloadMetadata, FileChecksum, PartDetails},
error::MetadataError,
fs_utils,
hash::HashDigest,
response_info::ResponseInfo,
};
use chrono::{DateTime, Utc};
use derive_builder::{Builder, UninitializedFieldError};
use reqwest::{
Proxy, Url,
header::{HeaderMap, HeaderName, HeaderValue},
};
use std::{
collections::HashMap,
path::{self, PathBuf},
str::FromStr,
};
use thiserror::Error;
use tokio::sync::Semaphore;
use ulid::Ulid;
/// State of a single download: its source URL, the working directory used
/// while downloading, the final destination, what the server reported about
/// the resource, and how the transfer is split into parts.
///
/// Instances are created through [`DownloadBuilder`], from a server response
/// ([`Download::from_response_info`]), or from persisted metadata
/// ([`Download::from_metadata`]).
#[derive(Builder, Debug, Clone)]
#[builder(build_fn(validate = "Self::validate", error = "DownloadBuilderError"))]
pub struct Download {
/// Working directory holding the part files, the lock file and the metadata file.
download_dir: path::PathBuf,
/// Source URL of the download.
url: Url,
/// Whether the server reported resume support; when false only one part is planned
/// (see `from_response_info`).
#[builder(default = false)]
is_resumable: bool,
/// Whether to use the server-provided time (`last_modified`).
/// NOTE(review): presumably applied to the final file's timestamp elsewhere — confirm.
#[builder(default = false)]
use_server_time: bool,
/// Name of the final file inside `save_dir`.
filename: String,
/// Directory the final file is placed in.
save_dir: path::PathBuf,
/// Total length reported by the server, if known.
#[builder(default = None)]
size: Option<u64>,
/// Expected digests of the file, as extracted from the server response.
#[builder(default = Vec::new())]
checksums: Vec<HashDigest>,
/// Last ETag reported by the server, if any.
#[builder(default = None)]
etag: Option<String>,
/// Last-Modified time as a Unix timestamp in seconds (see `last_modified_as_date`).
#[builder(default = None)]
last_modified: Option<i64>,
/// Whether the server required authentication.
#[builder(default = false)]
requires_auth: bool,
/// Whether the server required HTTP basic authentication.
#[builder(default = false)]
requires_basic_auth: bool,
/// Credentials for the request; not persisted in metadata.
#[builder(default = None)]
credentials: Option<Credentials>,
/// Proxy for the request; not persisted in metadata.
#[builder(default = None)]
proxy: Option<Proxy>,
/// Extra request headers; persisted in metadata as string pairs.
#[builder(default = None)]
headers: Option<HeaderMap>,
/// Upper bound on parallel connections (defaults to 6).
#[builder(default = 6)]
max_connections: u64,
/// Planned parts of the file, keyed by each part's ULID.
parts: HashMap<String, PartDetails>,
/// Whether the whole download has completed.
#[builder(default = false)]
finished: bool,
}
impl Download {
const METADATA_FILENAME: &'static str = "metadata.pb";
const METADATA_TEMP_FILENAME: &'static str = "metadata.pb.temp";
const LOCK_FILENAME: &'static str = "odl.lock";
pub const PART_EXTENSION: &'static str = "part";
const MIN_PART_SIZE: u64 = 300 * 1024; pub const ASSEMBLY_CLUSTER_SIZE: u64 = 4096;
const _ASSERT_CLUSTER_POW2: () =
assert!(Self::ASSEMBLY_CLUSTER_SIZE.is_power_of_two());
const _ASSERT_MIN_PART_GE_CLUSTER: () =
assert!(Self::MIN_PART_SIZE >= Self::ASSEMBLY_CLUSTER_SIZE);
pub fn download_dir(&self) -> &path::PathBuf {
&self.download_dir
}
pub fn part_path(&self, ulid: &str) -> path::PathBuf {
self.download_dir
.join(format!("{}.{}", ulid, Self::PART_EXTENSION))
}
pub fn set_download_dir(&mut self, path: PathBuf) {
self.download_dir = path
}
pub fn lockfile_path(&self) -> path::PathBuf {
self.download_dir.join(Self::LOCK_FILENAME)
}
pub fn metadata_path(&self) -> path::PathBuf {
self.download_dir.join(Self::METADATA_FILENAME)
}
pub fn metadata_temp_path(&self) -> path::PathBuf {
self.download_dir.join(Self::METADATA_TEMP_FILENAME)
}
pub fn final_file_path(&self) -> path::PathBuf {
self.save_dir.join(&self.filename)
}
pub fn url(&self) -> &Url {
&self.url
}
pub fn is_resumable(&self) -> bool {
self.is_resumable
}
pub fn use_server_time(&self) -> bool {
self.use_server_time
}
pub fn filename(&self) -> &str {
&self.filename
}
pub fn set_filename(&mut self, filename: String) {
self.filename = filename;
}
pub fn save_dir(&self) -> &path::PathBuf {
&self.save_dir
}
pub fn set_save_dir(&mut self, path: PathBuf) {
self.save_dir = path
}
pub fn size(&self) -> Option<u64> {
self.size
}
pub fn etag(&self) -> &Option<String> {
&self.etag
}
pub fn last_modified(&self) -> Option<i64> {
self.last_modified
}
pub fn last_modified_as_date(&self) -> Option<DateTime<Utc>> {
self.last_modified
.and_then(|x| chrono::DateTime::from_timestamp(x, 0))
}
pub fn requires_auth(&self) -> bool {
self.requires_auth
}
pub fn requires_basic_auth(&self) -> bool {
self.requires_basic_auth
}
pub fn credentials(&self) -> &Option<Credentials> {
&self.credentials
}
pub fn proxy(&self) -> &Option<Proxy> {
&self.proxy
}
pub fn headers(&self) -> &Option<HeaderMap> {
&self.headers
}
pub fn max_connections(&self) -> u64 {
self.max_connections
}
pub fn parts(&self) -> &HashMap<String, PartDetails> {
&self.parts
}
pub fn finished(&self) -> bool {
self.finished
}
pub fn from_metadata(
download_dir: path::PathBuf,
metadata: DownloadMetadata,
) -> Result<Download, MetadataError> {
let url = Url::parse(&metadata.url).map_err(|e| MetadataError::Other {
message: e.to_string(),
})?;
Ok(Self {
download_dir,
url,
is_resumable: metadata.is_resumable,
use_server_time: metadata.use_server_time,
filename: metadata.filename, save_dir: PathBuf::from(metadata.save_dir),
etag: metadata.last_etag,
last_modified: metadata.last_modified,
size: metadata.size,
checksums: metadata
.checksums
.into_iter()
.map(|c| c.try_into())
.collect::<Result<Vec<HashDigest>, _>>()
.unwrap_or_default(),
credentials: None,
requires_auth: metadata.requires_auth,
requires_basic_auth: metadata.requires_basic_auth,
proxy: None,
headers: if metadata.headers.is_empty() {
None
} else {
let mut map = HeaderMap::new();
for (k, v) in metadata.headers {
if let (Ok(header_name), Ok(header_value)) =
(HeaderName::from_str(&k), HeaderValue::from_str(&v))
{
map.insert(header_name, header_value);
}
}
Some(map)
},
max_connections: metadata.max_connections,
parts: metadata.parts,
finished: metadata.finished,
})
}
pub fn as_metadata(&self) -> DownloadMetadata {
DownloadMetadata {
url: self.url.to_string(),
filename: self.filename.clone(),
save_dir: self.save_dir.to_string_lossy().into_owned(),
is_resumable: self.is_resumable,
use_server_time: self.use_server_time,
last_modified: self.last_modified,
last_etag: self.etag.clone(),
size: self.size,
checksums: self
.checksums
.iter()
.map(|h| h.clone().into())
.collect::<Vec<FileChecksum>>(),
requires_auth: self.requires_auth,
requires_basic_auth: self.requires_basic_auth,
headers: self
.headers
.as_ref()
.map(|h| {
h.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect()
})
.unwrap_or_default(),
max_connections: self.max_connections,
parts: self.parts.clone(),
finished: self.finished,
}
}
#[allow(clippy::too_many_arguments)]
pub fn from_response_info(
download_dir: &std::path::Path,
save_dir: path::PathBuf,
response_info: ResponseInfo,
max_connections: u64,
use_server_time: bool,
credentials: Option<Credentials>,
proxy: Option<Proxy>,
headers: Option<HeaderMap>,
) -> Download {
let filename = fs_utils::cleanup_filename(response_info.extract_filename().as_str());
Self {
download_dir: download_dir.join(&filename),
url: response_info.url().clone(),
is_resumable: response_info.is_resumable(),
use_server_time,
filename,
save_dir,
etag: response_info.etag(),
last_modified: response_info.parse_last_modified(),
size: response_info.total_length(),
checksums: response_info.extract_hashes(),
credentials,
requires_auth: response_info.requires_auth(),
requires_basic_auth: response_info.requires_basic_auth(),
proxy,
headers,
max_connections,
parts: Download::determine_parts(
response_info.total_length(),
if response_info.is_resumable() {
max_connections
} else {
1
},
),
finished: false,
}
}
pub fn determine_parts(
size: Option<u64>,
max_connections: u64,
) -> HashMap<String, PartDetails> {
let mut parts = HashMap::new();
let max_connections = if max_connections > 0 {
max_connections
} else {
1
};
let size = size.unwrap_or(0);
if size <= Self::MIN_PART_SIZE {
let ulid = Ulid::new().to_string();
parts.insert(
ulid.clone(),
PartDetails {
offset: 0,
size,
ulid,
finished: size == 0,
},
);
return parts;
}
let mut actual_connections = max_connections;
let min_connections = size.div_ceil(Self::MIN_PART_SIZE);
if actual_connections > min_connections {
actual_connections = min_connections;
}
let mask = Self::ASSEMBLY_CLUSTER_SIZE - 1;
let base_size = (size / actual_connections) & !mask;
let mut offset = 0;
for i in 0..actual_connections {
let part_size = if i == actual_connections - 1 {
size - offset
} else {
base_size
};
let ulid = Ulid::new().to_string();
parts.insert(
ulid.clone(),
PartDetails {
offset,
size: part_size,
ulid,
finished: false,
},
);
offset += part_size;
}
parts
}
}
/// Identity comparison: two downloads are equal when they share the source
/// URL, the working directory and the file name. All other state (parts,
/// progress, credentials, headers, ...) is ignored.
impl PartialEq for Download {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (&self.url, &self.download_dir, &self.filename);
        let rhs = (&other.url, &other.download_dir, &other.filename);
        lhs == rhs
    }
}
impl DownloadBuilder {
    /// Checks builder state before `build` constructs the [`Download`].
    ///
    /// This runs on the builder itself, before field defaults are applied. A
    /// field that is still `None` here may therefore just be relying on its
    /// `#[builder(default = ...)]`.
    ///
    /// # Errors
    /// - `MissingDownloadDir` / `MissingSaveDir` / `MissingUrl` /
    ///   `MissingFilename` when the corresponding required field is unset.
    /// - [`DownloadBuilderError::InvalidNumConnections`] when `max_connections`
    ///   was explicitly set to 0 or to a value not below tokio's
    ///   `Semaphore::MAX_PERMITS`.
    fn validate(&self) -> Result<(), DownloadBuilderError> {
        if self.download_dir.is_none() {
            return Err(DownloadBuilderError::MissingDownloadDir);
        }
        if self.save_dir.is_none() {
            return Err(DownloadBuilderError::MissingSaveDir);
        }
        if self.url.is_none() {
            return Err(DownloadBuilderError::MissingUrl);
        }
        if self.filename.is_none() {
            return Err(DownloadBuilderError::MissingFilename);
        }
        // `MAX_PERMITS` is a `usize`; the fallback only matters if it cannot
        // be represented as `u64`.
        let permit_limit: u64 = Semaphore::MAX_PERMITS.try_into().unwrap_or(1_000_000);
        // `None` means the caller relies on the field default, which is valid.
        // Rejecting `None` here made `#[builder(default = 6)]` unreachable.
        if self
            .max_connections
            .is_some_and(|x| x == 0 || x >= permit_limit)
        {
            return Err(DownloadBuilderError::InvalidNumConnections);
        }
        Ok(())
    }
}
#[derive(Error, Debug)]
pub enum DownloadBuilderError {
#[error("download_dir is required")]
MissingDownloadDir,
#[error("save_dir is required")]
MissingSaveDir,
#[error("url is required")]
MissingUrl,
#[error("filename is required")]
MissingFilename,
#[error("max_connections must be at least 1")]
InvalidNumConnections,
#[error("uninitialized field: {0}")]
UninitializedField(String),
#[error("validation error: {0}")]
ValidationError(String),
}
impl From<String> for DownloadBuilderError {
fn from(s: String) -> Self {
Self::ValidationError(s)
}
}
impl From<UninitializedFieldError> for DownloadBuilderError {
fn from(ufe: UninitializedFieldError) -> Self {
Self::UninitializedField(ufe.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the parts ordered by offset; the map returned by
    /// `determine_parts` has no defined iteration order.
    fn by_offset(parts: &HashMap<String, PartDetails>) -> Vec<&PartDetails> {
        let mut ordered: Vec<&PartDetails> = parts.values().collect();
        ordered.sort_by_key(|p| p.offset);
        ordered
    }

    /// Asserts the layout has exactly one part and returns it.
    fn single_part(parts: &HashMap<String, PartDetails>) -> &PartDetails {
        assert_eq!(parts.len(), 1);
        parts.values().next().expect("length was checked above")
    }

    /// Sum of all part sizes.
    fn total_size(parts: &[&PartDetails]) -> u64 {
        parts.iter().map(|p| p.size).sum()
    }

    /// Asserts the first part starts at 0 and each part begins exactly where
    /// the previous one ends.
    fn assert_contiguous(parts: &[&PartDetails]) {
        let mut expected = 0;
        for part in parts {
            assert_eq!(part.offset, expected);
            expected += part.size;
        }
    }

    #[test]
    fn test_determine_parts_zero_size() {
        let layout = Download::determine_parts(Some(0), 4);
        let part = single_part(&layout);
        assert_eq!(part.offset, 0);
        assert_eq!(part.size, 0);
        assert!(part.finished);
    }

    #[test]
    fn test_determine_parts_zero_connections() {
        let layout = Download::determine_parts(Some(1024 * 1024), 0);
        let part = single_part(&layout);
        assert_eq!(part.offset, 0);
        assert_eq!(part.size, 1024 * 1024);
        assert!(!part.finished);
    }

    #[test]
    fn test_determine_parts_small_file() {
        let size = 200 * 1024;
        let layout = Download::determine_parts(Some(size), 4);
        let part = single_part(&layout);
        assert_eq!(part.offset, 0);
        assert_eq!(part.size, size);
        assert!(!part.finished);
    }

    #[test]
    fn test_determine_parts_exact_min_part_size() {
        let size = 300 * 1024;
        let layout = Download::determine_parts(Some(size), 4);
        let part = single_part(&layout);
        assert_eq!(part.offset, 0);
        assert_eq!(part.size, size);
        assert!(!part.finished);
    }

    #[test]
    fn test_determine_parts_even_split() {
        let size = 1024 * 1024;
        let max_connections = 4;
        let layout = Download::determine_parts(Some(size), max_connections);
        assert_eq!(layout.len(), max_connections as usize);
        let ordered = by_offset(&layout);
        assert_eq!(total_size(&ordered), size);
        assert_contiguous(&ordered);
    }

    #[test]
    fn test_determine_parts_uneven_split() {
        let size = 1024 * 1024 + 123;
        let max_connections = 3;
        let layout = Download::determine_parts(Some(size), max_connections);
        assert_eq!(layout.len(), max_connections as usize);
        let ordered = by_offset(&layout);
        assert_eq!(total_size(&ordered), size);
        assert!(ordered[2].size >= ordered[1].size);
        assert_eq!(ordered[0].size, ordered[1].size);
    }

    #[test]
    fn test_determine_parts_too_many_connections() {
        let size = 900 * 1024;
        let max_connections = 10;
        let layout = Download::determine_parts(Some(size), max_connections);
        assert_eq!(layout.len(), 3);
        let ordered = by_offset(&layout);
        assert_eq!(total_size(&ordered), size);
    }

    #[test]
    fn test_determine_parts_800kb_file() {
        let size = 800 * 1024;
        let max_connections = 10;
        let layout = Download::determine_parts(Some(size), max_connections);
        assert_eq!(layout.len(), 3);
        let ordered = by_offset(&layout);
        assert_eq!(total_size(&ordered), size);
        assert_contiguous(&ordered);
        assert_eq!(ordered[0].size, ordered[1].size);
        assert!(ordered[2].size >= ordered[1].size);
        for part in &ordered {
            assert_eq!(part.offset % Download::ASSEMBLY_CLUSTER_SIZE, 0);
        }
    }
}