use std::io::SeekFrom;
use std::path::Path;
use std::sync::Arc;
use async_fs::{File, OpenOptions};
use futures_lite::io::AsyncSeekExt;
use futures_lite::{AsyncWriteExt, StreamExt};
use sha2::{Digest as Sha2Digest, Sha256};
use crate::client::Client;
use crate::error::{Error, ErrorKind, Result};
use crate::rate_limit::RateLimiter;
use crate::response::StatusCode;
pub use crate::rate_limit::RateLimiter as DownloadRateLimiter;
/// What a server reported (via a HEAD probe) about how a resource may be
/// fetched; see `probe_url` for how each field is derived from headers.
#[derive(Clone, Debug, Default)]
pub struct DownloadCapabilities {
    /// True when the server sent `Accept-Ranges: bytes` (case-insensitive).
    pub supports_range: bool,
    /// Parsed `Content-Length` header, when present and numeric.
    pub content_length: Option<u64>,
    /// Raw `ETag` header value, when present.
    pub etag: Option<String>,
    /// Raw `Last-Modified` header value, when present.
    pub last_modified: Option<String>,
}
/// Digest algorithms supported for download hashing/verification.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HashAlgorithm {
    /// SHA-256.
    Sha256,
    /// MD5 — kept for legacy checksum files; not collision-resistant.
    Md5,
}
/// A finished digest of the downloaded bytes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DownloadDigest {
    /// Algorithm the digest was computed with.
    pub algorithm: HashAlgorithm,
    /// Lowercase hexadecimal rendering of the digest (produced via `{:x}`).
    pub hex: String,
}
/// Fluent configuration for a single download, borrowed from a `Client`.
pub struct DownloadBuilder<'a> {
    client: &'a Client,
    url: String,
    // Number of parallel range requests to attempt; clamped to >= 1.
    chunks: usize,
    // Optional throttle applied to every received chunk.
    rate_limiter: Option<RateLimiter>,
    // When set, a running digest is computed over the downloaded bytes.
    hash_algorithm: Option<HashAlgorithm>,
    // When set (via `verify`), the computed digest must match this hex string.
    expected_hash: Option<String>,
}
impl<'a> DownloadBuilder<'a> {
    /// Start a builder for `url` with defaults: one connection, no
    /// throttling, no hashing.
    pub(crate) fn new(client: &'a Client, url: impl Into<String>) -> Self {
        Self {
            client,
            url: url.into(),
            chunks: 1,
            rate_limiter: None,
            hash_algorithm: None,
            expected_hash: None,
        }
    }

    /// Request up to `n` parallel range requests; `0` is treated as `1`.
    pub fn chunks(mut self, n: usize) -> Self {
        self.chunks = if n == 0 { 1 } else { n };
        self
    }

    /// Throttle this download with a private limiter at `bytes_per_sec`.
    pub fn rate_limit(mut self, bytes_per_sec: u64) -> Self {
        let limiter = RateLimiter::new(bytes_per_sec);
        self.rate_limiter = Some(limiter);
        self
    }

    /// Throttle this download with a limiter shared across downloads.
    pub fn rate_limit_shared(mut self, limiter: RateLimiter) -> Self {
        self.rate_limiter = Some(limiter);
        self
    }

    /// Compute a digest of the downloaded bytes using `algorithm`.
    pub fn hash(mut self, algorithm: HashAlgorithm) -> Self {
        self.hash_algorithm = Some(algorithm);
        self
    }

    /// Compute a digest and fail the download unless it equals `expected_hex`.
    pub fn verify(mut self, algorithm: HashAlgorithm, expected_hex: impl Into<String>) -> Self {
        self.expected_hash = Some(expected_hex.into());
        self.hash_algorithm = Some(algorithm);
        self
    }

    /// Issue only the HEAD probe and report the server's capabilities.
    pub async fn probe(self) -> Result<DownloadCapabilities> {
        probe_url(self.client, &self.url).await
    }

    /// Download to `dest`, using parallel range requests when more than one
    /// chunk was requested AND the probe showed range support with a known
    /// length; otherwise fall back to a single streaming GET.
    pub async fn save(self, dest: &Path) -> Result<DownloadResult> {
        let Self {
            client,
            url,
            chunks,
            rate_limiter,
            hash_algorithm,
            expected_hash,
        } = self;
        let caps = probe_url(client, &url).await?;
        let parallel = chunks > 1 && caps.supports_range && caps.content_length.is_some();
        if parallel {
            download_chunked(
                client,
                &url,
                dest,
                &caps,
                chunks,
                rate_limiter,
                hash_algorithm,
                expected_hash,
            )
            .await
        } else {
            download_single(
                client,
                &url,
                dest,
                &caps,
                rate_limiter,
                hash_algorithm,
                expected_hash,
            )
            .await
        }
    }
}
/// Summary of a completed download.
#[derive(Debug)]
pub struct DownloadResult {
    /// Total number of body bytes written to the destination file.
    pub total_bytes: u64,
    /// Range requests actually used (1 for the single-GET path).
    pub chunks_used: usize,
    /// Capabilities probed from the server before the download started.
    pub capabilities: DownloadCapabilities,
    /// Digest of the downloaded bytes, when hashing was requested.
    pub digest: Option<DownloadDigest>,
}
async fn probe_url(client: &Client, url: &str) -> Result<DownloadCapabilities> {
let resp = client.head(url).await?;
if !resp.status().is_success() {
return Err(Error::new(
ErrorKind::Transport,
format!(
"probe: server returned {} for HEAD {url}",
resp.status().as_u16()
),
));
}
let h = resp.headers();
Ok(DownloadCapabilities {
supports_range: h
.get("accept-ranges")
.map(|v| v.eq_ignore_ascii_case("bytes"))
.unwrap_or(false),
content_length: h.get("content-length").and_then(|v| v.parse().ok()),
etag: h.get("etag").map(str::to_owned),
last_modified: h.get("last-modified").map(str::to_owned),
})
}
async fn download_single(
client: &Client,
url: &str,
dest: &Path,
caps: &DownloadCapabilities,
rate_limiter: Option<RateLimiter>,
hash_algorithm: Option<HashAlgorithm>,
expected_hash: Option<String>,
) -> Result<DownloadResult> {
let resp = client.get(url).await?;
if !resp.status().is_success() {
return Err(Error::new(
ErrorKind::Transport,
format!(
"download: server returned {} for {url}",
resp.status().as_u16()
),
));
}
let mut file = File::create(dest).await.map_err(io_err)?;
let mut stream = resp.bytes_stream();
let mut hash = hash_algorithm.map(RunningHash::new);
let mut total_bytes: u64 = 0;
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
if let Some(rl) = &rate_limiter {
rl.acquire(chunk.len()).await;
}
if let Some(h) = hash.as_mut() {
h.update(&chunk);
}
file.write_all(&chunk).await.map_err(io_err)?;
total_bytes += chunk.len() as u64;
}
file.flush().await.map_err(io_err)?;
let digest = hash.map(|h| h.finalize(hash_algorithm.unwrap()));
verify_digest(&digest, &expected_hash)?;
Ok(DownloadResult {
total_bytes,
chunks_used: 1,
capabilities: caps.clone(),
digest,
})
}
/// Download `url` with `n` concurrent byte-range requests and assemble the
/// pieces into `dest`.
///
/// Caller invariant (enforced in `DownloadBuilder::save`): the probe showed
/// range support and `caps.content_length` is `Some`.
#[allow(clippy::too_many_arguments)]
async fn download_chunked(
    client: &Client,
    url: &str,
    dest: &Path,
    caps: &DownloadCapabilities,
    n: usize,
    rate_limiter: Option<RateLimiter>,
    hash_algorithm: Option<HashAlgorithm>,
    expected_hash: Option<String>,
) -> Result<DownloadResult> {
    // Safe per the caller invariant above.
    let total = caps.content_length.unwrap();
    // Pre-size the destination so each range can later be written at its
    // own offset.
    File::create(dest)
        .await
        .map_err(io_err)?
        .set_len(total)
        .await
        .map_err(io_err)?;
    let ranges = byte_ranges(total, n);
    // May be fewer than `n` when the file has fewer bytes than chunks.
    let actual = ranges.len();
    // A single limiter shared by all range futures so the throttle applies
    // to the download as a whole, not per connection.
    let shared_rl: Option<Arc<RateLimiter>> = rate_limiter.map(Arc::new);
    let url_arc = Arc::new(url.to_owned());
    let mut futures = Vec::with_capacity(actual);
    for (start, end) in ranges.iter().copied() {
        let url = Arc::clone(&url_arc);
        let rl = shared_rl.clone();
        futures.push(async move { fetch_range(client, &url, start, end, rl).await });
    }
    // NOTE(review): every chunk is buffered fully in memory before any write
    // below, so peak memory is roughly the file size — confirm acceptable.
    let chunk_results = join_all(futures).await;
    let mut file = OpenOptions::new()
        .write(true)
        .open(dest)
        .await
        .map_err(io_err)?;
    let mut hash = hash_algorithm.map(RunningHash::new);
    let mut total_bytes: u64 = 0;
    // `join_all` preserves input order, so iterating results sequentially
    // hashes the file bytes from start to end.
    for (idx, result) in chunk_results.into_iter().enumerate() {
        let data = result?;
        let (start, _) = ranges[idx];
        file.seek(SeekFrom::Start(start)).await.map_err(io_err)?;
        file.write_all(&data).await.map_err(io_err)?;
        if let Some(h) = hash.as_mut() {
            h.update(&data);
        }
        total_bytes += data.len() as u64;
    }
    file.flush().await.map_err(io_err)?;
    // `hash` is Some exactly when `hash_algorithm` is Some, so the unwrap
    // cannot fire.
    let digest = hash.map(|h| h.finalize(hash_algorithm.unwrap()));
    verify_digest(&digest, &expected_hash)?;
    Ok(DownloadResult {
        total_bytes,
        chunks_used: actual,
        capabilities: caps.clone(),
        digest,
    })
}
/// Fetch the inclusive byte range `start..=end` of `url` into memory.
///
/// The body is validated against the requested length: a server that
/// ignores the `Range` header and replies `200 OK` with the full resource
/// (or one that truncates the stream) would otherwise be accepted here and
/// corrupt the assembled file, because the caller writes this buffer at
/// offset `start`.
async fn fetch_range(
    client: &Client,
    url: &str,
    start: u64,
    end: u64,
    rate_limiter: Option<Arc<RateLimiter>>,
) -> Result<Vec<u8>> {
    let range_value = format!("bytes={start}-{end}");
    let resp = client.get(url).header("Range", &range_value)?.await?;
    let status = resp.status();
    if status != StatusCode::PARTIAL_CONTENT && !status.is_success() {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("range {range_value} returned status {}", status.as_u16()),
        ));
    }
    // Inclusive range: start <= end always holds for `byte_ranges` output.
    let expected_len = end - start + 1;
    let mut stream = resp.bytes_stream();
    let mut buf = Vec::with_capacity(expected_len as usize);
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        if let Some(rl) = &rate_limiter {
            rl.acquire(chunk.len()).await;
        }
        buf.extend_from_slice(&chunk);
    }
    // Reject short or oversized bodies. An oversized body usually means the
    // server ignored the Range header and sent the whole resource.
    if buf.len() as u64 != expected_len {
        return Err(Error::new(
            ErrorKind::Transport,
            format!(
                "range {range_value}: expected {expected_len} bytes, received {}",
                buf.len()
            ),
        ));
    }
    Ok(buf)
}
/// Minimal ordered join: drives every future to completion on the current
/// task and returns their outputs in input order.
///
/// Each wake re-polls only the futures that have not yet produced a value.
/// Uses the standard library's `std::future::poll_fn` (stable since Rust
/// 1.64) instead of the `futures_lite` equivalent — same semantics, one
/// fewer external touch point.
async fn join_all<F, T>(futures: Vec<F>) -> Vec<T>
where
    F: std::future::Future<Output = T>,
{
    let mut pinned: Vec<std::pin::Pin<Box<dyn std::future::Future<Output = T>>>> =
        futures.into_iter().map(|f| Box::pin(f) as _).collect();
    let mut results: Vec<Option<T>> = (0..pinned.len()).map(|_| None).collect();
    let mut remaining = pinned.len();
    std::future::poll_fn(|cx| {
        for (i, fut) in pinned.iter_mut().enumerate() {
            // Already finished — a completed future must not be polled again.
            if results[i].is_some() {
                continue;
            }
            if let std::task::Poll::Ready(v) = fut.as_mut().poll(cx) {
                results[i] = Some(v);
                remaining -= 1;
            }
        }
        if remaining == 0 {
            std::task::Poll::Ready(())
        } else {
            std::task::Poll::Pending
        }
    })
    .await;
    // Every slot is Some once `remaining` reached zero above.
    results.into_iter().map(|r| r.unwrap()).collect()
}
/// Split `total` bytes into at most `n` contiguous inclusive `(start, end)`
/// ranges that cover `0..total` without gaps or overlap.
///
/// Returns fewer than `n` ranges when `total < n`, and an empty vector when
/// `total == 0`. The arithmetic is overflow-safe: the previous
/// `(total + n - 1) / n` ceiling division and `start + chunk_size - 1`
/// both wrapped (panicking in debug builds) for `total` near `u64::MAX`.
fn byte_ranges(total: u64, n: usize) -> Vec<(u64, u64)> {
    let n = n as u64;
    // Ceiling division without the overflowing `total + n - 1` intermediate.
    let chunk_size = total / n + u64::from(total % n != 0);
    let mut ranges = Vec::new();
    let mut start = 0u64;
    while start < total {
        // chunk_size >= 1 whenever total > 0, so `chunk_size - 1` is safe;
        // saturating_add guards the final range when `start` is huge.
        let end = start.saturating_add(chunk_size - 1).min(total - 1);
        ranges.push((start, end));
        start = end + 1;
    }
    ranges
}
fn io_err(e: std::io::Error) -> Error {
Error::new(ErrorKind::Transport, format!("io: {e}"))
}
/// Compare a computed digest against an expected hex string, if both exist.
///
/// The comparison is ASCII-case-insensitive: computed digests are rendered
/// lowercase (`{:x}`), but callers frequently paste uppercase hex from
/// checksum files — the previous exact `==` falsely rejected those.
/// Missing digest or missing expectation is not an error (nothing to check).
fn verify_digest(actual: &Option<DownloadDigest>, expected: &Option<String>) -> Result<()> {
    if let (Some(actual), Some(expected)) = (actual.as_ref(), expected.as_ref()) {
        if !actual.hex.eq_ignore_ascii_case(expected) {
            return Err(Error::new(
                ErrorKind::Decode,
                format!("hash mismatch: expected {expected} got {}", actual.hex),
            ));
        }
    }
    Ok(())
}
/// Incremental hash state — one variant per supported `HashAlgorithm`.
pub(crate) enum RunningHash {
    Sha256(Sha256),
    Md5(md5::Context),
}
impl RunningHash {
fn new(alg: HashAlgorithm) -> Self {
match alg {
HashAlgorithm::Sha256 => Self::Sha256(Sha256::new()),
HashAlgorithm::Md5 => Self::Md5(md5::Context::new()),
}
}
fn update(&mut self, data: &[u8]) {
match self {
Self::Sha256(h) => sha2::Digest::update(h, data),
Self::Md5(h) => h.consume(data),
}
}
fn finalize(self, alg: HashAlgorithm) -> DownloadDigest {
let hex = match self {
Self::Sha256(h) => format!("{:x}", sha2::Digest::finalize(h)),
Self::Md5(h) => format!("{:x}", h.finalize()),
};
DownloadDigest {
algorithm: alg,
hex,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // 100 bytes over 4 chunks: evenly divisible, 25 bytes per range.
    #[test]
    fn byte_ranges_even() {
        let r = byte_ranges(100, 4);
        assert_eq!(r, vec![(0, 24), (25, 49), (50, 74), (75, 99)]);
    }

    // 10 bytes over 3 chunks: ceiling division yields 4-4-2.
    #[test]
    fn byte_ranges_uneven() {
        let r = byte_ranges(10, 3);
        assert_eq!(r.len(), 3);
        assert_eq!(r[0], (0, 3));
        assert_eq!(r[1], (4, 7));
        assert_eq!(r[2], (8, 9));
    }

    // A single chunk covers the whole file as one inclusive range.
    #[test]
    fn byte_ranges_single() {
        assert_eq!(byte_ranges(50, 1), vec![(0, 49)]);
    }

    // More chunks than bytes: one range per byte, no empty ranges.
    #[test]
    fn byte_ranges_more_chunks_than_bytes() {
        let r = byte_ranges(3, 10);
        assert_eq!(r, vec![(0, 0), (1, 1), (2, 2)]);
    }

    // Known SHA-256 test vector for the input "hello".
    #[test]
    fn sha256_known_vector() {
        let mut h = RunningHash::new(HashAlgorithm::Sha256);
        h.update(b"hello");
        let d = h.finalize(HashAlgorithm::Sha256);
        assert_eq!(
            d.hex,
            "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
        );
    }

    // Known MD5 test vector for the input "hello".
    #[test]
    fn md5_known_vector() {
        let mut h = RunningHash::new(HashAlgorithm::Md5);
        h.update(b"hello");
        let d = h.finalize(HashAlgorithm::Md5);
        assert_eq!(d.hex, "5d41402abc4b2a76b9719d911017c592");
    }

    // Matching digest and expectation passes verification.
    #[test]
    fn verify_ok() {
        let d = Some(DownloadDigest {
            algorithm: HashAlgorithm::Sha256,
            hex: "abc".into(),
        });
        assert!(verify_digest(&d, &Some("abc".into())).is_ok());
    }

    // A mismatch surfaces as a Decode error.
    #[test]
    fn verify_mismatch() {
        let d = Some(DownloadDigest {
            algorithm: HashAlgorithm::Sha256,
            hex: "abc".into(),
        });
        let e = verify_digest(&d, &Some("xyz".into())).unwrap_err();
        assert_eq!(e.kind(), &ErrorKind::Decode);
    }
}