use crate::config::DownloadConfig;
use crate::utils::{format_speed, parse_content_range_total, pwrite_all};
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::header::{self, HeaderMap, HeaderValue};
use reqwest::{Client, StatusCode};
use std::fs::File;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::io::{AsyncWriteExt, BufWriter};
/// Boxed, thread-safe error type shared by every fallible step of the download pipeline.
type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// Result alias whose error side is always [`BoxError`].
type Result<T> = std::result::Result<T, BoxError>;
/// Downloads HTTP resources to local files, splitting the transfer across several ranged
/// connections when the server supports byte ranges (see `download`).
pub struct DownloadOptimizer {
    /// HTTP client shared by probing and all download workers; built by `build_client`.
    client: Client,
    /// Tuning parameters: connection count, piece size, retries, timeouts, buffering, quiet mode.
    config: DownloadConfig,
}
/// What `DownloadOptimizer::probe` learned about a remote resource.
#[derive(Debug, Clone)]
struct Probe {
    /// Size of the resource in bytes, or 0 when the server did not report it.
    total_size: u64,
    /// Whether the server advertised or honoured byte-range requests with a known size.
    supports_ranges: bool,
    /// URL reported by the probe response (i.e. after any redirects); later requests use it.
    final_url: String,
}
impl DownloadOptimizer {
    /// Creates an optimizer whose HTTP client is built from `config`.
    ///
    /// # Errors
    /// Returns an error if the HTTP client cannot be constructed.
    pub async fn new(config: DownloadConfig) -> Result<Self> {
        let client = build_client(&config)?;
        Ok(Self { client, config })
    }

    /// Downloads `url` into the file at `filename`.
    ///
    /// The resource is probed first. The download is split into parallel pieces when all of
    /// these hold:
    /// - the server supports byte ranges;
    /// - the size exceeds `min_chunk_size`;
    /// - more than one connection is allowed.
    ///
    /// Otherwise the resource is streamed over a single connection.
    ///
    /// # Errors
    /// Propagates probe, HTTP, and file I/O errors.
    pub async fn download(&self, url: &str, filename: &str) -> Result<()> {
        let probe = self.probe(url).await?;
        if probe.supports_ranges
            && probe.total_size > self.config.min_chunk_size
            && self.config.max_connections > 1
        {
            self.parallel_download(&probe.final_url, filename, probe.total_size)
                .await
        } else {
            self.single_download(&probe.final_url, filename, probe.total_size)
                .await
        }
    }

    /// Discovers the final URL, total size, and byte-range support of `url`.
    ///
    /// A `HEAD` request is tried first. If it does not report both a non-zero size and
    /// `Accept-Ranges: bytes`, a `GET` for the first byte (`Range: bytes=0-0`) is sent and its
    /// status and `Content-Range` are inspected instead.
    ///
    /// # Errors
    /// Fails if the fallback `GET` cannot be sent or returns a non-success status.
    async fn probe(&self, url: &str) -> Result<Probe> {
        if let Ok(resp) = self.client.head(url).send().await
            && resp.status().is_success()
        {
            // Read the size from the `Content-Length` header itself. A HEAD response has no
            // body, and `Response::content_length()` is derived from the body, so it may report
            // 0 even when the server announced the real size.
            let total = resp
                .headers()
                .get(header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.trim().parse::<u64>().ok())
                .or_else(|| resp.content_length())
                .unwrap_or(0);
            let ranges = resp
                .headers()
                .get(header::ACCEPT_RANGES)
                .and_then(|v| v.to_str().ok())
                .map(|v| v.eq_ignore_ascii_case("bytes"))
                .unwrap_or(false);
            let final_url = resp.url().to_string();
            if ranges && total > 0 {
                return Ok(Probe {
                    total_size: total,
                    supports_ranges: true,
                    final_url,
                });
            }
        }
        // Fallback: request the first byte only. A 206 reply proves range support, and its
        // `Content-Range` carries the total size.
        let resp = self
            .client
            .get(url)
            .header(header::RANGE, "bytes=0-0")
            .send()
            .await?;
        let final_url = resp.url().to_string();
        if resp.status() == StatusCode::PARTIAL_CONTENT {
            let total = resp
                .headers()
                .get(header::CONTENT_RANGE)
                .and_then(|v| v.to_str().ok())
                .and_then(parse_content_range_total)
                .unwrap_or(0);
            return Ok(Probe {
                total_size: total,
                supports_ranges: total > 0,
                final_url,
            });
        }
        if !resp.status().is_success() {
            return Err(format!("HTTP error while probing: {}", resp.status()).into());
        }
        Ok(Probe {
            total_size: resp.content_length().unwrap_or(0),
            supports_ranges: false,
            final_url,
        })
    }

    /// Downloads `total_size` bytes of `url` into `filename` with up to `max_connections`
    /// concurrent workers that claim fixed-size pieces from a shared cursor.
    ///
    /// The output file is pre-sized to `total_size`, so pieces can be written at their
    /// offsets in any order. If any worker fails, the remaining workers are cancelled and the
    /// first error is returned. The partially written file is left on disk.
    ///
    /// # Errors
    /// Returns file-creation errors, or the first worker error (including task panics).
    async fn parallel_download(&self, url: &str, filename: &str, total_size: u64) -> Result<()> {
        let file = File::create(filename)?;
        file.set_len(total_size)?;
        let file = Arc::new(file);
        // Pieces are never smaller than 64 KiB, whatever the configuration says.
        let piece_size = self.config.piece_size.max(64 * 1024);
        let cursor = Arc::new(AtomicU64::new(0));
        let downloaded = Arc::new(AtomicU64::new(0));
        let pb = self.make_progress(total_size);
        let progress = spawn_progress_updater(pb.clone(), downloaded.clone(), self.config.quiet);

        let mut workers = tokio::task::JoinSet::new();
        for _ in 0..self.config.max_connections {
            let client = self.client.clone();
            let url = url.to_string();
            let file = file.clone();
            let cursor = cursor.clone();
            let downloaded = downloaded.clone();
            let config = self.config.clone();
            workers.spawn(async move {
                worker_loop(
                    client, url, file, total_size, piece_size, cursor, downloaded, config,
                )
                .await
            });
        }

        // Handle results in completion order so a failure is noticed as soon as it happens.
        let mut first_error: Option<BoxError> = None;
        while let Some(joined) = workers.join_next().await {
            let failure = match joined {
                Ok(Ok(())) => continue,
                Ok(Err(e)) => e,
                Err(e) => Box::new(e) as BoxError,
            };
            first_error = Some(failure);
            // A single missing piece makes the output unusable. Cancel the other workers
            // instead of letting them fetch the rest of the file, and wait for them to stop.
            workers.shutdown().await;
            break;
        }

        progress.stop().await;
        if let Some(e) = first_error {
            return Err(e);
        }
        finalize_progress(&pb, total_size);
        Ok(())
    }

    /// Streams `url` into `filename` over a single connection.
    ///
    /// `hint_size` (from the probe) sizes the progress bar. When it is 0, the response's
    /// `Content-Length` is used instead, and a spinner is shown if that is unknown as well.
    ///
    /// # Errors
    /// Fails on a non-success status, a body stream error, or a file write error.
    async fn single_download(&self, url: &str, filename: &str, hint_size: u64) -> Result<()> {
        let resp = self.client.get(url).send().await?;
        if !resp.status().is_success() {
            return Err(format!("HTTP error: {}", resp.status()).into());
        }
        let total = if hint_size > 0 {
            hint_size
        } else {
            resp.content_length().unwrap_or(0)
        };
        let pb = self.make_progress(total);
        let file = tokio::fs::File::create(filename).await?;
        let mut writer = BufWriter::with_capacity(self.config.write_buffer, file);
        let mut stream = resp.bytes_stream();
        let mut downloaded: u64 = 0;
        let mut last_draw = Instant::now();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            writer.write_all(&chunk).await?;
            downloaded += chunk.len() as u64;
            // Throttle redraws to roughly one every 80 ms.
            if !self.config.quiet && last_draw.elapsed() >= Duration::from_millis(80) {
                pb.set_position(downloaded);
                last_draw = Instant::now();
            }
        }
        writer.flush().await?;
        finalize_progress(&pb, downloaded);
        Ok(())
    }

    /// Creates the progress display for a transfer of `total` bytes.
    ///
    /// Returns a hidden bar in quiet mode, a spinner when the size is unknown (0), and a
    /// byte-count bar with throughput and ETA otherwise.
    fn make_progress(&self, total: u64) -> ProgressBar {
        if self.config.quiet {
            return ProgressBar::hidden();
        }
        if total == 0 {
            return ProgressBar::new_spinner();
        }
        let pb = ProgressBar::new(total);
        pb.set_style(
            ProgressStyle::default_bar()
                .template(
                    "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] \
                     {bytes}/{total_bytes} ({bytes_per_sec}, {eta})",
                )
                .expect("progress template is a valid static string")
                .progress_chars("#>-"),
        );
        pb
    }
}
fn build_client(config: &DownloadConfig) -> Result<Client> {
let mut headers = HeaderMap::new();
headers.insert(
header::USER_AGENT,
HeaderValue::from_static(concat!("dw/", env!("CARGO_PKG_VERSION"))),
);
headers.insert(header::ACCEPT, HeaderValue::from_static("*/*"));
headers.insert(
header::ACCEPT_ENCODING,
HeaderValue::from_static("identity"),
);
let client = Client::builder()
.default_headers(headers)
.pool_max_idle_per_host(config.max_connections.max(1))
.http1_only()
.tcp_nodelay(true)
.tcp_keepalive(Duration::from_secs(60))
.connect_timeout(config.connect_timeout)
.build()?;
Ok(client)
}
/// Claims `piece_size`-byte pieces from the shared `cursor` and downloads them one at a time,
/// until the cursor moves past `total_size`.
///
/// Pieces are inclusive byte ranges. The last piece is shortened to end at `total_size - 1`.
///
/// # Errors
/// Returns the first error from `download_piece`. The worker stops claiming pieces after it.
#[allow(clippy::too_many_arguments)] // each shared handle is moved in by value from the spawner
async fn worker_loop(
    client: Client,
    url: String,
    file: Arc<File>,
    total_size: u64,
    piece_size: u64,
    cursor: Arc<AtomicU64>,
    downloaded: Arc<AtomicU64>,
    config: DownloadConfig,
) -> Result<()> {
    loop {
        let first_byte = cursor.fetch_add(piece_size, Ordering::Relaxed);
        if first_byte >= total_size {
            return Ok(());
        }
        let last_byte = total_size.min(first_byte + piece_size) - 1;
        download_piece(
            &client,
            &url,
            &file,
            first_byte,
            last_byte,
            &downloaded,
            &config,
        )
        .await?;
    }
}
async fn download_piece(
client: &Client,
url: &str,
file: &Arc<File>,
start: u64,
end: u64,
downloaded: &Arc<AtomicU64>,
config: &DownloadConfig,
) -> Result<()> {
let piece_len = end - start + 1;
let mut flushed: u64 = 0;
let mut attempt: u32 = 0;
loop {
let resume_from = start + flushed;
let attempt_fut = stream_range(
client,
url,
file,
resume_from,
end,
&mut flushed,
downloaded,
config.write_buffer,
);
let outcome = if config.piece_timeout.is_zero() {
attempt_fut.await
} else {
match tokio::time::timeout(config.piece_timeout, attempt_fut).await {
Ok(res) => res,
Err(_) => Err("piece stalled (timeout)".into()),
}
};
match outcome {
Ok(()) if flushed >= piece_len => return Ok(()),
Ok(()) => {}
Err(e) => {
if attempt >= config.max_retries {
return Err(format!(
"piece {start}-{end} failed after {} attempts: {e}",
attempt + 1
)
.into());
}
}
}
attempt += 1;
if attempt > config.max_retries {
return Err(format!("piece {start}-{end} exceeded retry budget").into());
}
let backoff = Duration::from_millis(100u64 << attempt.min(6));
tokio::time::sleep(backoff).await;
}
}
#[allow(clippy::too_many_arguments)]
async fn stream_range(
client: &Client,
url: &str,
file: &Arc<File>,
resume_from: u64,
end: u64,
flushed: &mut u64,
downloaded: &Arc<AtomicU64>,
write_buffer: usize,
) -> Result<()> {
let range = format!("bytes={resume_from}-{end}");
let resp = client
.get(url)
.header(header::RANGE, range)
.send()
.await?;
if resp.status() != StatusCode::PARTIAL_CONTENT {
return Err(format!("expected 206 Partial Content, got {}", resp.status()).into());
}
let mut stream = resp.bytes_stream();
let mut buf: Vec<u8> = Vec::with_capacity(write_buffer);
let mut write_off = resume_from;
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
buf.extend_from_slice(&chunk);
if buf.len() >= write_buffer {
buf = flush_buf(file, &mut write_off, buf, flushed, downloaded).await?;
}
}
if !buf.is_empty() {
flush_buf(file, &mut write_off, buf, flushed, downloaded).await?;
}
Ok(())
}
/// Writes `buf` to `file` at offset `*write_off` and hands the emptied buffer back for reuse.
///
/// After a successful write, `write_off`, `flushed` and the shared `downloaded` counter all
/// advance by the number of bytes written. The positional write runs on tokio's blocking pool,
/// so it does not stall the async executor.
///
/// # Errors
/// Returns the write error, or a join error if the blocking task panicked.
async fn flush_buf(
    file: &Arc<File>,
    write_off: &mut u64,
    buf: Vec<u8>,
    flushed: &mut u64,
    downloaded: &Arc<AtomicU64>,
) -> Result<Vec<u8>> {
    let written = buf.len() as u64;
    let offset = *write_off;
    let target = Arc::clone(file);
    let job = move || -> std::io::Result<Vec<u8>> {
        pwrite_all(&target, offset, &buf)?;
        Ok(buf)
    };
    let mut reusable = tokio::task::spawn_blocking(job).await??;
    reusable.clear();
    *write_off += written;
    *flushed += written;
    downloaded.fetch_add(written, Ordering::Relaxed);
    Ok(reusable)
}
/// Handle to the background task that mirrors the shared byte counter onto a progress bar.
///
/// The task only exits once its stop flag is set (or immediately in quiet mode). Dropping
/// this handle without awaiting `stop` detaches a task that keeps redrawing forever, which
/// is why the type is `#[must_use]`.
#[must_use = "call `stop().await` to end the progress updater task"]
struct ProgressUpdater {
    /// The spawned updater task.
    handle: tokio::task::JoinHandle<()>,
    /// Set to ask the task to exit after its next redraw.
    stop: Arc<std::sync::atomic::AtomicBool>,
}
impl ProgressUpdater {
    /// Asks the updater task to exit and waits until it has finished.
    ///
    /// A panic or cancellation of the task is ignored; it only affects the display.
    async fn stop(self) {
        let Self { handle, stop } = self;
        stop.store(true, Ordering::Relaxed);
        _ = handle.await;
    }
}
/// Spawns a task that copies `downloaded` onto `pb` every 80 ms until it is stopped.
///
/// Each redraw happens before the stop flag is checked, so the latest count is drawn once
/// more after a stop is requested. In quiet mode the task exits immediately without drawing.
fn spawn_progress_updater(
    pb: ProgressBar,
    downloaded: Arc<AtomicU64>,
    quiet: bool,
) -> ProgressUpdater {
    let stop = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let stop_requested = Arc::clone(&stop);
    let handle = tokio::spawn(async move {
        if !quiet {
            let mut ticker = tokio::time::interval(Duration::from_millis(80));
            loop {
                ticker.tick().await;
                pb.set_position(downloaded.load(Ordering::Relaxed));
                if stop_requested.load(Ordering::Relaxed) {
                    break;
                }
            }
        }
    });
    ProgressUpdater { handle, stop }
}
/// Completes `pb` at `total` bytes and shows the average throughput over the bar's elapsed
/// time.
///
/// Hidden (quiet) bars are left untouched.
fn finalize_progress(pb: &ProgressBar, total: u64) {
    if !pb.is_hidden() {
        pb.set_position(total);
        let seconds = pb.elapsed().as_secs_f64();
        let avg_speed = match seconds {
            s if s > 0.0 => total as f64 / s,
            _ => 0.0,
        };
        pb.finish_with_message(format!("done — avg {}", format_speed(avg_speed)));
    }
}