use std::path::{Path, PathBuf};
use futures_util::StreamExt;
use pulith_fs::workflow::Workspace;
use pulith_verify::{Hasher, Sha256Hasher};
use serde::{Deserialize, Serialize};
use crate::config::{FetchOptions, FetchPhase};
use crate::error::{Error, Result};
use crate::net::http::HttpClient;
use crate::progress::PerformanceMetrics;
use crate::progress::Progress;
use crate::rate::retry_delay;
pub struct Fetcher<C: HttpClient> {
pub(crate) client: C,
workspace_root: PathBuf,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum FetchSource {
Url(String),
LocalPath(PathBuf),
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FetchReceipt {
pub source: FetchSource,
pub destination: PathBuf,
pub bytes_downloaded: u64,
pub total_bytes: Option<u64>,
pub sha256_hex: Option<String>,
}
impl<C: HttpClient> Fetcher<C> {
pub fn new(client: C, workspace_root: impl Into<PathBuf>) -> Self {
Self {
client,
workspace_root: workspace_root.into(),
}
}
#[tracing::instrument(skip(self), fields(url = %url))]
pub async fn head(&self, url: &str) -> Result<Option<u64>> {
self.client
.head(url)
.await
.map_err(|e| Error::Network(e.to_string()))
}
#[tracing::instrument(skip(self, options), fields(url = %url, destination = %destination.display()))]
pub async fn fetch_with_receipt(
&self,
url: &str,
destination: &Path,
options: FetchOptions,
) -> Result<FetchReceipt> {
let mut attempt = 0u32;
loop {
match self
.fetch_with_receipt_attempt(url, destination, &options, attempt)
.await
{
Ok(receipt) => return Ok(receipt),
Err(error) => {
if !matches!(error, Error::Network(_) | Error::Timeout(_)) {
return Err(error);
}
if attempt >= options.retry_policy.max_retries {
return Err(Error::MaxRetriesExceeded { count: attempt + 1 });
}
let delay = retry_delay(attempt, options.retry_policy.base_backoff);
if let Some(provider) = &options.retry_delay_provider {
(provider)(delay).await;
} else {
tokio::time::sleep(delay).await;
}
attempt += 1;
}
}
}
}
#[tracing::instrument(skip(self, options), fields(url = %url, destination = %destination.display(), retry_count = retry_count))]
async fn fetch_with_receipt_attempt(
&self,
url: &str,
destination: &Path,
options: &FetchOptions,
retry_count: u32,
) -> Result<FetchReceipt> {
let start_time = std::time::Instant::now();
let mut performance_metrics = PerformanceMetrics::default();
let connecting_start = std::time::Instant::now();
self.report_progress(
options,
Progress {
phase: FetchPhase::Connecting,
bytes_downloaded: 0,
total_bytes: None,
retry_count,
performance_metrics: Some(performance_metrics.clone()),
},
);
let total_bytes = options.expected_bytes.or(self
.client
.head(url)
.await
.map_err(|e| Error::Network(e.to_string()))?);
let connecting_duration = connecting_start.elapsed();
performance_metrics.phase_timings.connecting_ms = connecting_duration.as_millis() as u64;
performance_metrics.connection_time_ms = Some(connecting_duration.as_millis() as u64);
self.report_progress(
options,
Progress {
phase: FetchPhase::Connecting,
bytes_downloaded: 0,
total_bytes,
retry_count,
performance_metrics: Some(performance_metrics.clone()),
},
);
let mut request_headers: Vec<(String, String)> = options.headers.iter().cloned().collect();
if let Some(offset) = options.resume_offset {
request_headers.push(("Range".to_string(), format!("bytes={offset}-")));
}
let staging_dir = self.workspace_root.join("staging");
let dest_dir = destination.parent().unwrap_or_else(|| Path::new("."));
let workspace = Workspace::new(&staging_dir, dest_dir)?;
let staging_file_path = workspace.path().join(
destination
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("download")),
);
let mut stream = self
.client
.stream(url, &request_headers)
.await
.map_err(|e| Error::Network(e.to_string()))?;
let mut hasher = Sha256Hasher::new();
let downloading_start = std::time::Instant::now();
self.report_progress(
options,
Progress {
phase: FetchPhase::Downloading,
bytes_downloaded: options.resume_offset.unwrap_or(0),
total_bytes,
retry_count,
performance_metrics: Some(performance_metrics.clone()),
},
);
let mut bytes_downloaded = options.resume_offset.unwrap_or(0);
let mut last_progress_time = std::time::Instant::now();
let mut last_bytes_downloaded = bytes_downloaded;
use tokio::io::AsyncWriteExt;
let mut file = tokio::fs::File::create(&staging_file_path)
.await
.map_err(|e| Error::Network(e.to_string()))?;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| Error::Network(e.to_string()))?;
hasher.update(&chunk);
file.write_all(&chunk)
.await
.map_err(|e| Error::Network(e.to_string()))?;
bytes_downloaded += chunk.len() as u64;
let now = std::time::Instant::now();
if now.duration_since(last_progress_time).as_millis() >= 100 {
let time_diff = now.duration_since(last_progress_time).as_secs_f64();
let bytes_diff = bytes_downloaded - last_bytes_downloaded;
if time_diff > 0.0 {
performance_metrics.current_rate_bps = Some(bytes_diff as f64 / time_diff);
}
last_progress_time = now;
last_bytes_downloaded = bytes_downloaded;
}
let total_time = start_time.elapsed().as_secs_f64();
if total_time > 0.0 {
performance_metrics.average_rate_bps = Some(bytes_downloaded as f64 / total_time);
}
self.report_progress(
options,
Progress {
phase: FetchPhase::Downloading,
bytes_downloaded,
total_bytes,
retry_count,
performance_metrics: Some(performance_metrics.clone()),
},
);
}
let downloading_duration = downloading_start.elapsed();
performance_metrics.phase_timings.downloading_ms = downloading_duration.as_millis() as u64;
let verifying_start = std::time::Instant::now();
self.report_progress(
options,
Progress {
phase: FetchPhase::Verifying,
bytes_downloaded,
total_bytes,
retry_count,
performance_metrics: Some(performance_metrics.clone()),
},
);
let actual_checksum = hasher.finalize();
if let Some(expected_checksum) = options.checksum
&& actual_checksum != expected_checksum
{
return Err(Error::ChecksumMismatch {
expected: hex::encode(expected_checksum),
actual: hex::encode(actual_checksum),
});
}
let verifying_duration = verifying_start.elapsed();
performance_metrics.phase_timings.verifying_ms = verifying_duration.as_millis() as u64;
drop(file);
let committing_start = std::time::Instant::now();
self.report_progress(
options,
Progress {
phase: FetchPhase::Committing,
bytes_downloaded,
total_bytes,
retry_count,
performance_metrics: Some(performance_metrics.clone()),
},
);
workspace
.commit()
.map_err(|e| Error::Network(e.to_string()))?;
let committing_duration = committing_start.elapsed();
performance_metrics.phase_timings.committing_ms = committing_duration.as_millis() as u64;
self.report_progress(
options,
Progress {
phase: FetchPhase::Completed,
bytes_downloaded,
total_bytes,
retry_count,
performance_metrics: Some(performance_metrics),
},
);
Ok(FetchReceipt {
source: FetchSource::Url(url.to_string()),
destination: destination.to_path_buf(),
bytes_downloaded,
total_bytes,
sha256_hex: Some(hex::encode(actual_checksum)),
})
}
fn report_progress(&self, options: &FetchOptions, progress: Progress) {
if let Some(ref callback) = options.on_progress {
callback(&progress);
}
}
#[tracing::instrument(skip(self, source, options), fields(source = %source.url, destination = %destination.display()))]
pub async fn try_source(
&self,
source: &crate::DownloadSource,
destination: &Path,
options: &FetchOptions,
) -> Result<FetchReceipt> {
let mut fetch_options = options.clone();
fetch_options.checksum = source.checksum;
self.fetch_with_receipt(&source.url, destination, fetch_options)
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{FetchOptions, FetchPhase};
use crate::net::http::BoxStream;
use crate::progress::Progress;
use bytes::Bytes;
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Debug)]
struct MockError(String);
impl std::fmt::Display for MockError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl std::error::Error for MockError {}
struct MockHttpClient {
should_fail: bool,
content_length: Option<u64>,
}
impl MockHttpClient {
fn new() -> Self {
Self {
should_fail: false,
content_length: Some(1024),
}
}
fn with_error() -> Self {
Self {
should_fail: true,
content_length: None,
}
}
fn without_content_length() -> Self {
Self {
should_fail: false,
content_length: None,
}
}
}
impl HttpClient for MockHttpClient {
type Error = MockError;
async fn stream(
&self,
_url: &str,
_headers: &[(String, String)],
) -> std::result::Result<
BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
Self::Error,
> {
if self.should_fail {
Err(MockError("Stream failed".to_string()))
} else {
let stream = futures_util::stream::once(async { Ok(Bytes::from("test data")) });
Ok(Box::pin(stream)
as BoxStream<
'static,
std::result::Result<Bytes, Self::Error>,
>)
}
}
async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
if self.should_fail {
Err(MockError("HEAD request failed".to_string()))
} else {
Ok(self.content_length)
}
}
}
#[tokio::test]
async fn test_fetcher_new() {
let client = MockHttpClient::new();
let workspace_root = "/tmp/test_workspace";
let fetcher = Fetcher::new(client, workspace_root);
assert_eq!(fetcher.workspace_root, PathBuf::from(workspace_root));
}
#[tokio::test]
async fn test_fetcher_head_success() {
let client = MockHttpClient::new();
let fetcher = Fetcher::new(client, "/tmp");
let result = fetcher.head("http://example.com").await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), Some(1024));
}
#[tokio::test]
async fn test_fetcher_head_without_content_length() {
let client = MockHttpClient::without_content_length();
let fetcher = Fetcher::new(client, "/tmp");
let result = fetcher.head("http://example.com").await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), None);
}
#[tokio::test]
async fn test_fetcher_head_error() {
let client = MockHttpClient::with_error();
let fetcher = Fetcher::new(client, "/tmp");
let result = fetcher.head("http://example.com").await;
assert!(result.is_err());
match result.unwrap_err() {
Error::Network(msg) => assert!(msg.contains("HEAD request failed")),
_ => panic!("Expected Network error"),
}
}
#[tokio::test]
async fn test_fetcher_fetch_success() {
let client = MockHttpClient::new();
let fetcher = Fetcher::new(client, "/tmp");
let url = "http://example.com";
let destination = PathBuf::from("/tmp/test_file");
let options = FetchOptions::default();
let result = fetcher.fetch_with_receipt(url, &destination, options).await;
assert!(result.is_ok() || result.is_err());
}
#[tokio::test]
async fn test_fetcher_fetch_with_progress_callback() {
let client = MockHttpClient::new();
let fetcher = Fetcher::new(client, "/tmp");
let url = "http://example.com";
let destination = PathBuf::from("/tmp/test_file");
let progress_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
let progress_called_clone = progress_called.clone();
let options = FetchOptions {
on_progress: Some(Arc::new(move |_progress| {
progress_called_clone.store(true, std::sync::atomic::Ordering::Relaxed);
})),
..Default::default()
};
let _result = fetcher.fetch_with_receipt(url, &destination, options).await;
let _ = progress_called.load(std::sync::atomic::Ordering::Relaxed);
}
#[tokio::test]
async fn test_try_source() {
let client = MockHttpClient::new();
let fetcher = Fetcher::new(client, "/tmp");
let source = crate::DownloadSource::new("http://example.com".to_string());
let destination = PathBuf::from("/tmp/test_file");
let options = FetchOptions::default();
let result = fetcher.try_source(&source, &destination, &options).await;
assert!(result.is_ok() || result.is_err());
}
#[tokio::test]
async fn fetch_retries_with_explicit_retry_policy() {
use std::sync::atomic::{AtomicU32, Ordering};
struct AlwaysFailingHttpClient {
stream_calls: Arc<AtomicU32>,
}
impl HttpClient for AlwaysFailingHttpClient {
type Error = MockError;
async fn stream(
&self,
_url: &str,
_headers: &[(String, String)],
) -> std::result::Result<
BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
Self::Error,
> {
let _ = self.stream_calls.fetch_add(1, Ordering::SeqCst);
Err(MockError("stream always fails".to_string()))
}
async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
Ok(Some(9))
}
}
let stream_calls = Arc::new(AtomicU32::new(0));
let client = AlwaysFailingHttpClient {
stream_calls: Arc::clone(&stream_calls),
};
let temp = tempfile::tempdir().unwrap();
let fetcher = Fetcher::new(client, temp.path());
let options = FetchOptions::default().retry_policy(crate::RetryPolicy {
max_retries: 1,
base_backoff: std::time::Duration::from_millis(1),
});
let error = fetcher
.fetch_with_receipt(
"http://example.com",
&temp.path().join("retry.bin"),
options,
)
.await
.unwrap_err();
assert!(matches!(error, Error::MaxRetriesExceeded { count: 2 }));
assert_eq!(stream_calls.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn fetch_retries_can_use_custom_delay_provider() {
use std::sync::atomic::{AtomicU32, Ordering};
struct AlwaysFailingHttpClient;
impl HttpClient for AlwaysFailingHttpClient {
type Error = MockError;
async fn stream(
&self,
_url: &str,
_headers: &[(String, String)],
) -> std::result::Result<
BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
Self::Error,
> {
Err(MockError("stream always fails".to_string()))
}
async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
Ok(Some(9))
}
}
let delay_calls = Arc::new(AtomicU32::new(0));
let delay_calls_for_provider = Arc::clone(&delay_calls);
let temp = tempfile::tempdir().unwrap();
let fetcher = Fetcher::new(AlwaysFailingHttpClient, temp.path());
let options = FetchOptions::default()
.retry_policy(crate::RetryPolicy {
max_retries: 2,
base_backoff: std::time::Duration::from_millis(1),
})
.retry_delay_provider(Arc::new(move |_delay| {
let delay_calls_for_provider = Arc::clone(&delay_calls_for_provider);
Box::pin(async move {
delay_calls_for_provider.fetch_add(1, Ordering::SeqCst);
})
}));
let error = fetcher
.fetch_with_receipt(
"http://example.com",
&temp.path().join("retry-custom-delay.bin"),
options,
)
.await
.unwrap_err();
assert!(matches!(error, Error::MaxRetriesExceeded { count: 3 }));
assert_eq!(delay_calls.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn fetch_applies_resume_offset_as_range_header() {
use std::sync::Mutex;
struct HeaderCaptureHttpClient {
seen_headers: Arc<Mutex<Vec<(String, String)>>>,
}
impl HttpClient for HeaderCaptureHttpClient {
type Error = MockError;
async fn stream(
&self,
_url: &str,
headers: &[(String, String)],
) -> std::result::Result<
BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
Self::Error,
> {
*self.seen_headers.lock().unwrap() = headers.to_vec();
Err(MockError("fail after header capture".to_string()))
}
async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
Ok(Some(256))
}
}
let seen_headers = Arc::new(Mutex::new(Vec::<(String, String)>::new()));
let client = HeaderCaptureHttpClient {
seen_headers: Arc::clone(&seen_headers),
};
let temp = tempfile::tempdir().unwrap();
let fetcher = Fetcher::new(client, temp.path());
let options = FetchOptions::default()
.retry_policy(crate::RetryPolicy {
max_retries: 0,
base_backoff: std::time::Duration::from_millis(1),
})
.resume_offset(Some(128))
.expected_bytes(Some(256));
let error = fetcher
.fetch_with_receipt(
"http://example.com",
&temp.path().join("resume.bin"),
options,
)
.await
.unwrap_err();
assert!(matches!(error, Error::MaxRetriesExceeded { count: 1 }));
let headers = seen_headers.lock().unwrap().clone();
assert!(
headers
.iter()
.any(|(k, v)| k == "Range" && v == "bytes=128-")
);
}
#[test]
fn test_report_progress_without_callback() {
let client = MockHttpClient::new();
let fetcher = Fetcher::new(client, "/tmp");
let options = FetchOptions::default();
let progress = Progress {
phase: FetchPhase::Connecting,
bytes_downloaded: 0,
total_bytes: None,
retry_count: 0,
performance_metrics: None,
};
fetcher.report_progress(&options, progress);
}
#[test]
fn test_report_progress_with_callback() {
let client = MockHttpClient::new();
let fetcher = Fetcher::new(client, "/tmp");
let callback_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
let callback_called_clone = callback_called.clone();
let options = FetchOptions {
on_progress: Some(Arc::new(move |_progress| {
callback_called_clone.store(true, std::sync::atomic::Ordering::Relaxed);
})),
..Default::default()
};
let progress = Progress {
phase: FetchPhase::Connecting,
bytes_downloaded: 0,
total_bytes: None,
retry_count: 0,
performance_metrics: None,
};
fetcher.report_progress(&options, progress);
assert!(callback_called.load(std::sync::atomic::Ordering::Relaxed));
}
}