use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use crate::config::{DownloadSpec, LogLevel};
use crate::error::DownloadError;
use crate::logging::next_download_id;
use crate::network::ClientNetworkConfig;
use crate::progress::ProgressSnapshot;
use crate::session;
/// Entry point for starting downloads; create one with [`Downloader::builder`].
///
/// Holds the HTTP client shared by the downloads it starts, the network configuration that
/// client was built from, the log level passed to each download task, and an optional
/// semaphore that limits how many downloads run at once.
pub struct Downloader {
// Built once in `DownloaderBuilder::build`; cloned into each download task.
client: reqwest::Client,
// Configuration `client` was built from; kept so a spec with a different
// `connect_timeout` can get a freshly built client.
client_config: ClientNetworkConfig,
log_level: LogLevel,
// `Some` only when `max_concurrent_downloads` was configured; each download task holds a
// permit from it for its whole lifetime.
concurrency_limit: Option<Arc<Semaphore>>,
}
/// Configures and creates a [`Downloader`]; obtain one with [`Downloader::builder`].
///
/// Every setter consumes and returns the builder, so settings are chained and finished
/// with `build`.
#[must_use = "a DownloaderBuilder does nothing until `build` is called"]
pub struct DownloaderBuilder {
    /// Network settings used to build the shared HTTP client.
    client_config: ClientNetworkConfig,
    /// Log level handed to the built `Downloader` and the download tasks it spawns.
    log_level: LogLevel,
    /// Optional cap on simultaneously running downloads; `None` means no limit.
    max_concurrent_downloads: Option<usize>,
}
impl DownloaderBuilder {
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.client_config.connect_timeout = timeout;
self
}
pub fn all_proxy(mut self, proxy: impl Into<String>) -> Self {
self.client_config.all_proxy = Some(proxy.into());
self
}
pub fn http_proxy(mut self, proxy: impl Into<String>) -> Self {
self.client_config.http_proxy = Some(proxy.into());
self
}
pub fn https_proxy(mut self, proxy: impl Into<String>) -> Self {
self.client_config.https_proxy = Some(proxy.into());
self
}
pub fn dns_server(mut self, server: std::net::SocketAddr) -> Self {
self.client_config.dns_servers.push(server);
self
}
pub fn dns_servers<I>(mut self, servers: I) -> Self
where
I: IntoIterator<Item = std::net::SocketAddr>,
{
self.client_config.dns_servers = servers.into_iter().collect();
self
}
pub fn enable_ipv6(mut self, enabled: bool) -> Self {
self.client_config.enable_ipv6 = enabled;
self
}
pub fn log_level(mut self, level: LogLevel) -> Self {
self.log_level = level;
self
}
pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
self.max_concurrent_downloads = Some(limit);
self
}
pub fn build(self) -> Result<Downloader, DownloadError> {
let log_level = self.log_level;
let client = self.client_config.build_client()?;
log_debug!(
log_level,
log_level = %log_level,
connect_timeout_ms = self.client_config.connect_timeout.as_millis() as u64,
has_proxy = self.client_config.all_proxy.is_some()
|| self.client_config.http_proxy.is_some()
|| self.client_config.https_proxy.is_some(),
custom_dns_count = self.client_config.dns_servers.len(),
ipv6 = self.client_config.enable_ipv6,
"downloader built"
);
Ok(Downloader {
client,
client_config: self.client_config,
log_level,
concurrency_limit: self
.max_concurrent_downloads
.map(|n| Arc::new(Semaphore::new(n))),
})
}
}
impl Downloader {
pub fn builder() -> DownloaderBuilder {
DownloaderBuilder {
client_config: ClientNetworkConfig::default(),
log_level: LogLevel::default(),
max_concurrent_downloads: None,
}
}
pub fn download(&self, spec: DownloadSpec) -> DownloadHandle {
let (progress_tx, progress_rx) = watch::channel(ProgressSnapshot::default());
let (cancel_tx, cancel_rx) = watch::channel(session::StopSignal::Running);
let log_level = self.log_level;
let download_id = next_download_id();
if let Err(error) = spec.validate() {
log_error!(
log_level,
download_id,
url = %spec.url,
error = %error,
"download task rejected due to invalid configuration"
);
let task = tokio::spawn(async move { Err(error) });
return DownloadHandle {
progress_rx,
cancel_tx,
task,
};
}
let shared_client = self.client.clone();
let client_config = self.client_config.clone();
let output = spec
.output_path
.as_ref()
.map(|path| path.display().to_string())
.unwrap_or_else(|| "<auto>".to_string());
log_info!(
log_level,
download_id,
url = %spec.url,
output = %output,
max_connections = spec.max_connections,
resume = spec.resume,
"download task created"
);
let concurrency_limit = self.concurrency_limit.clone();
let task = tokio::spawn(async move {
let _permit = match &concurrency_limit {
Some(sem) => Some(sem.acquire().await.map_err(|_| {
DownloadError::Internal("concurrency semaphore closed".into())
})?),
None => None,
};
let client = if spec.connect_timeout == client_config.connect_timeout {
shared_client
} else {
client_config
.with_connect_timeout(spec.connect_timeout)
.build_client()?
};
session::run_download(client, spec, log_level, download_id, progress_tx, cancel_rx)
.await
});
DownloadHandle {
progress_rx,
cancel_tx,
task,
}
}
}
/// Handle to a download task spawned by [`Downloader::download`].
///
/// Exposes the task's progress channel, a stop-signal sender used by `cancel` and `pause`,
/// and the task's join handle, which `wait` consumes.
pub struct DownloadHandle {
// Receives the latest `ProgressSnapshot` published for this download.
progress_rx: watch::Receiver<ProgressSnapshot>,
// Sends `StopSignal::Cancel` / `StopSignal::Pause` to the running session.
cancel_tx: watch::Sender<session::StopSignal>,
// The spawned download task; awaited by `wait`. Dropping a tokio `JoinHandle` without
// awaiting it leaves the task running detached.
task: JoinHandle<Result<(), DownloadError>>,
}
impl DownloadHandle {
    /// Returns a copy of the most recently published progress snapshot.
    pub fn progress(&self) -> ProgressSnapshot {
        self.progress_rx.borrow().clone()
    }

    /// Returns a new receiver for observing progress updates directly.
    pub fn subscribe_progress(&self) -> watch::Receiver<ProgressSnapshot> {
        self.progress_rx.clone()
    }

    /// Spawns a background task that calls `callback` with each observed progress update.
    ///
    /// The listener stops after delivering a snapshot whose state is `Completed`, `Failed`,
    /// `Cancelled`, or `Paused`, or once the progress sender is dropped. A `watch` channel
    /// keeps only the latest value, so rapid intermediate updates may be coalesced.
    ///
    /// `callback` may be `FnMut`, so it can keep its own state; any `Fn` closure also works.
    pub fn on_progress<F>(&self, mut callback: F)
    where
        F: FnMut(ProgressSnapshot) + Send + 'static,
    {
        let mut rx = self.progress_rx.clone();
        tokio::spawn(async move {
            while rx.changed().await.is_ok() {
                // `borrow_and_update` marks the value read here as seen. A value published
                // between `changed()` and this read is therefore not delivered a second time.
                let snap = rx.borrow_and_update().clone();
                let terminal = matches!(
                    snap.state,
                    crate::progress::DownloadState::Completed
                        | crate::progress::DownloadState::Failed
                        | crate::progress::DownloadState::Cancelled
                        | crate::progress::DownloadState::Paused
                );
                callback(snap);
                if terminal {
                    break;
                }
            }
        });
    }

    /// Asks the session to cancel by sending `StopSignal::Cancel`.
    ///
    /// A send error (no receiver left, i.e. the task has already finished) is ignored.
    pub fn cancel(&self) {
        let _ = self.cancel_tx.send(session::StopSignal::Cancel);
    }

    /// Asks the session to pause by sending `StopSignal::Pause`.
    ///
    /// A send error (no receiver left, i.e. the task has already finished) is ignored.
    pub fn pause(&self) {
        let _ = self.cancel_tx.send(session::StopSignal::Pause);
    }

    /// Waits for the download task to finish and returns its result.
    ///
    /// # Errors
    ///
    /// Returns the task's own error. If the task could not complete (it panicked, or the
    /// runtime cancelled it), returns `DownloadError::TaskFailed` naming which one happened.
    pub async fn wait(self) -> Result<(), DownloadError> {
        match self.task.await {
            Ok(result) => result,
            Err(e) if e.is_cancelled() => {
                Err(DownloadError::TaskFailed(format!("task cancelled: {e}")))
            }
            Err(e) => Err(DownloadError::TaskFailed(format!("task panicked: {e}"))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::progress::DownloadState;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Local URL the tests expect nothing to serve, so real network work fails.
    const UNREACHABLE_URL: &str = "http://127.0.0.1:1/nonexistent";

    /// Spec for `UNREACHABLE_URL` that would write `file_name` inside the temp directory.
    fn unreachable_spec(file_name: &str) -> DownloadSpec {
        DownloadSpec::new(UNREACHABLE_URL).output_path(std::env::temp_dir().join(file_name))
    }

    /// Wraps a progress receiver and task in a handle whose cancel receiver is already dropped.
    fn handle_for(
        progress_rx: watch::Receiver<ProgressSnapshot>,
        task: JoinHandle<Result<(), DownloadError>>,
    ) -> DownloadHandle {
        let (cancel_tx, _) = watch::channel(session::StopSignal::Running);
        DownloadHandle {
            progress_rx,
            cancel_tx,
            task,
        }
    }

    #[test]
    fn test_downloader_builder_default() {
        let _downloader = Downloader::builder()
            .build()
            .expect("default builder should succeed");
    }

    #[test]
    fn test_downloader_builder_with_log_level() {
        let _downloader = Downloader::builder()
            .log_level(LogLevel::Debug)
            .build()
            .expect("builder with debug logging should succeed");
    }

    #[test]
    fn test_downloader_builder_custom_timeout() {
        let _downloader = Downloader::builder()
            .connect_timeout(Duration::from_secs(10))
            .build()
            .expect("builder with custom timeout should succeed");
    }

    #[test]
    fn test_downloader_builder_proxy_and_dns_options() {
        let resolver = std::net::SocketAddr::from(([1, 1, 1, 1], 53));
        let _downloader = Downloader::builder()
            .all_proxy("http://127.0.0.1:7890")
            .dns_server(resolver)
            .enable_ipv6(false)
            .build()
            .expect("proxy and DNS configuration should build");
    }

    #[tokio::test]
    async fn test_download_handle_progress_default() {
        let downloader = Downloader::builder().build().unwrap();
        let handle = downloader.download(unreachable_spec("bytehaul_test_never_created"));
        assert_eq!(handle.progress().state, DownloadState::Pending);
        let _rx = handle.subscribe_progress();
        handle.cancel();
        assert!(handle.wait().await.is_err());
    }

    #[tokio::test]
    async fn test_download_with_logging_enabled() {
        let downloader = Downloader::builder()
            .log_level(LogLevel::Debug)
            .build()
            .unwrap();
        let handle = downloader.download(unreachable_spec("bytehaul_test_log_enabled"));
        handle.cancel();
        assert!(handle.wait().await.is_err());
    }

    #[tokio::test]
    async fn test_download_rejects_invalid_spec_before_network_work() {
        let downloader = Downloader::builder().build().unwrap();
        let spec = unreachable_spec("bytehaul_test_invalid_spec").max_connections(0);
        let err = downloader.download(spec).wait().await.unwrap_err();
        assert!(matches!(
            err,
            DownloadError::InvalidConfig(message) if message.contains("max_connections")
        ));
    }

    #[test]
    fn test_downloader_builder_max_concurrent_downloads() {
        let downloader = Downloader::builder()
            .max_concurrent_downloads(3)
            .build()
            .unwrap();
        let permits = downloader
            .concurrency_limit
            .as_ref()
            .map(|semaphore| semaphore.available_permits());
        assert_eq!(permits, Some(3));
    }

    #[test]
    fn test_downloader_builder_no_concurrency_limit_by_default() {
        let downloader = Downloader::builder().build().unwrap();
        assert!(downloader.concurrency_limit.is_none());
    }

    #[test]
    fn test_downloader_builder_sets_scheme_specific_proxies_and_dns_servers() {
        let octets: [[u8; 4]; 2] = [[1, 1, 1, 1], [8, 8, 8, 8]];
        let servers: Vec<std::net::SocketAddr> = octets
            .iter()
            .map(|&ip| std::net::SocketAddr::from((ip, 53)))
            .collect();
        let builder = Downloader::builder()
            .http_proxy("http://127.0.0.1:8080")
            .https_proxy("http://127.0.0.1:8443")
            .dns_servers(servers.clone());
        let config = &builder.client_config;
        assert_eq!(config.http_proxy.as_deref(), Some("http://127.0.0.1:8080"));
        assert_eq!(config.https_proxy.as_deref(), Some("http://127.0.0.1:8443"));
        assert_eq!(config.dns_servers, servers);
    }

    #[tokio::test]
    async fn test_download_rebuilds_client_for_spec_timeout_override() {
        let downloader = Downloader::builder().build().unwrap();
        let mut spec = unreachable_spec("bytehaul_test_timeout_override");
        spec.connect_timeout = Duration::from_secs(1);
        let outcome = downloader.download(spec).wait().await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_download_reports_closed_concurrency_semaphore() {
        let downloader = Downloader::builder()
            .max_concurrent_downloads(1)
            .build()
            .unwrap();
        let semaphore = downloader
            .concurrency_limit
            .as_ref()
            .expect("semaphore should exist");
        semaphore.close();
        let err = downloader
            .download(unreachable_spec("bytehaul_test_closed_semaphore"))
            .wait()
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DownloadError::Internal(message) if message.contains("concurrency semaphore closed")
        ));
    }

    #[tokio::test]
    async fn test_download_handle_wait_maps_panics() {
        let (progress_tx, progress_rx) = watch::channel(ProgressSnapshot::default());
        drop(progress_tx);
        let task = tokio::spawn(async {
            panic!("boom");
            // Never reached; the `Ok(())` only fixes the future's output type.
            #[allow(unreachable_code)]
            Ok(())
        });
        let message = handle_for(progress_rx, task)
            .wait()
            .await
            .unwrap_err()
            .to_string();
        assert!(message.contains("task panicked"));
    }

    #[tokio::test]
    async fn test_on_progress_receives_updates() {
        let (progress_tx, progress_rx) = watch::channel(ProgressSnapshot::default());
        let handle = handle_for(progress_rx, tokio::spawn(async { Ok(()) }));
        let received = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&received);
        handle.on_progress(move |_| {
            counter.fetch_add(1, Ordering::Relaxed);
        });
        let updates = [
            ProgressSnapshot {
                state: DownloadState::Downloading,
                downloaded: 100,
                ..Default::default()
            },
            ProgressSnapshot {
                state: DownloadState::Completed,
                ..Default::default()
            },
        ];
        for update in updates {
            progress_tx.send(update).unwrap();
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        assert!(received.load(Ordering::Relaxed) >= 2);
    }

    #[tokio::test]
    async fn test_on_progress_stops_for_all_terminal_states() {
        let terminal_states = [
            DownloadState::Completed,
            DownloadState::Failed,
            DownloadState::Cancelled,
            DownloadState::Paused,
        ];
        for state in terminal_states {
            let (progress_tx, progress_rx) = watch::channel(ProgressSnapshot::default());
            let received = Arc::new(AtomicU32::new(0));
            let counter = Arc::clone(&received);
            let handle = handle_for(progress_rx, tokio::spawn(async { Ok(()) }));
            handle.on_progress(move |_| {
                counter.fetch_add(1, Ordering::Relaxed);
            });
            // The terminal update must be delivered; the later one must be ignored.
            let updates = [
                ProgressSnapshot {
                    state,
                    ..Default::default()
                },
                ProgressSnapshot {
                    state: DownloadState::Downloading,
                    downloaded: 1,
                    ..Default::default()
                },
            ];
            for update in updates {
                progress_tx.send(update).unwrap();
                tokio::time::sleep(Duration::from_millis(50)).await;
            }
            assert_eq!(received.load(Ordering::Relaxed), 1);
        }
    }
}