1use std::path::{Path, PathBuf};
2
3use futures_util::StreamExt;
4use pulith_fs::workflow::Workspace;
5use pulith_verify::{Hasher, Sha256Hasher};
6use serde::{Deserialize, Serialize};
7
8use crate::config::{FetchOptions, FetchPhase};
9use crate::error::{Error, Result};
10use crate::net::http::HttpClient;
11use crate::progress::PerformanceMetrics;
12use crate::progress::Progress;
13use crate::rate::retry_delay;
14
15pub struct Fetcher<C: HttpClient> {
17 pub(crate) client: C,
18 workspace_root: PathBuf,
19}
20
21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
22pub enum FetchSource {
23 Url(String),
24 LocalPath(PathBuf),
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
28pub struct FetchReceipt {
29 pub source: FetchSource,
30 pub destination: PathBuf,
31 pub bytes_downloaded: u64,
32 pub total_bytes: Option<u64>,
33 pub sha256_hex: Option<String>,
34}
35
36impl<C: HttpClient> Fetcher<C> {
37 pub fn new(client: C, workspace_root: impl Into<PathBuf>) -> Self {
39 Self {
40 client,
41 workspace_root: workspace_root.into(),
42 }
43 }
44
45 #[tracing::instrument(skip(self), fields(url = %url))]
47 pub async fn head(&self, url: &str) -> Result<Option<u64>> {
48 self.client
49 .head(url)
50 .await
51 .map_err(|e| Error::Network(e.to_string()))
52 }
53
54 #[tracing::instrument(skip(self, options), fields(url = %url, destination = %destination.display()))]
59 pub async fn fetch_with_receipt(
60 &self,
61 url: &str,
62 destination: &Path,
63 options: FetchOptions,
64 ) -> Result<FetchReceipt> {
65 let mut attempt = 0u32;
66 loop {
67 match self
68 .fetch_with_receipt_attempt(url, destination, &options, attempt)
69 .await
70 {
71 Ok(receipt) => return Ok(receipt),
72 Err(error) => {
73 if !matches!(error, Error::Network(_) | Error::Timeout(_)) {
74 return Err(error);
75 }
76
77 if attempt >= options.retry_policy.max_retries {
78 return Err(Error::MaxRetriesExceeded { count: attempt + 1 });
79 }
80
81 let delay = retry_delay(attempt, options.retry_policy.base_backoff);
82 if let Some(provider) = &options.retry_delay_provider {
83 (provider)(delay).await;
84 } else {
85 tokio::time::sleep(delay).await;
86 }
87 attempt += 1;
88 }
89 }
90 }
91 }
92
93 #[tracing::instrument(skip(self, options), fields(url = %url, destination = %destination.display(), retry_count = retry_count))]
94 async fn fetch_with_receipt_attempt(
95 &self,
96 url: &str,
97 destination: &Path,
98 options: &FetchOptions,
99 retry_count: u32,
100 ) -> Result<FetchReceipt> {
101 let start_time = std::time::Instant::now();
102 let mut performance_metrics = PerformanceMetrics::default();
103
104 let connecting_start = std::time::Instant::now();
105 self.report_progress(
106 options,
107 Progress {
108 phase: FetchPhase::Connecting,
109 bytes_downloaded: 0,
110 total_bytes: None,
111 retry_count,
112 performance_metrics: Some(performance_metrics.clone()),
113 },
114 );
115
116 let total_bytes = options.expected_bytes.or(self
117 .client
118 .head(url)
119 .await
120 .map_err(|e| Error::Network(e.to_string()))?);
121
122 let connecting_duration = connecting_start.elapsed();
123 performance_metrics.phase_timings.connecting_ms = connecting_duration.as_millis() as u64;
124 performance_metrics.connection_time_ms = Some(connecting_duration.as_millis() as u64);
125
126 self.report_progress(
127 options,
128 Progress {
129 phase: FetchPhase::Connecting,
130 bytes_downloaded: 0,
131 total_bytes,
132 retry_count,
133 performance_metrics: Some(performance_metrics.clone()),
134 },
135 );
136
137 let mut request_headers: Vec<(String, String)> = options.headers.iter().cloned().collect();
138 if let Some(offset) = options.resume_offset {
139 request_headers.push(("Range".to_string(), format!("bytes={offset}-")));
140 }
141
142 let staging_dir = self.workspace_root.join("staging");
143 let dest_dir = destination.parent().unwrap_or_else(|| Path::new("."));
144 let workspace = Workspace::new(&staging_dir, dest_dir)?;
145 let staging_file_path = workspace.path().join(
146 destination
147 .file_name()
148 .unwrap_or_else(|| std::ffi::OsStr::new("download")),
149 );
150
151 let mut stream = self
152 .client
153 .stream(url, &request_headers)
154 .await
155 .map_err(|e| Error::Network(e.to_string()))?;
156 let mut hasher = Sha256Hasher::new();
157
158 let downloading_start = std::time::Instant::now();
159 self.report_progress(
160 options,
161 Progress {
162 phase: FetchPhase::Downloading,
163 bytes_downloaded: options.resume_offset.unwrap_or(0),
164 total_bytes,
165 retry_count,
166 performance_metrics: Some(performance_metrics.clone()),
167 },
168 );
169
170 let mut bytes_downloaded = options.resume_offset.unwrap_or(0);
171 let mut last_progress_time = std::time::Instant::now();
172 let mut last_bytes_downloaded = bytes_downloaded;
173 use tokio::io::AsyncWriteExt;
174 let mut file = tokio::fs::File::create(&staging_file_path)
175 .await
176 .map_err(|e| Error::Network(e.to_string()))?;
177
178 while let Some(chunk_result) = stream.next().await {
179 let chunk = chunk_result.map_err(|e| Error::Network(e.to_string()))?;
180 hasher.update(&chunk);
181 file.write_all(&chunk)
182 .await
183 .map_err(|e| Error::Network(e.to_string()))?;
184 bytes_downloaded += chunk.len() as u64;
185
186 let now = std::time::Instant::now();
187 if now.duration_since(last_progress_time).as_millis() >= 100 {
188 let time_diff = now.duration_since(last_progress_time).as_secs_f64();
189 let bytes_diff = bytes_downloaded - last_bytes_downloaded;
190 if time_diff > 0.0 {
191 performance_metrics.current_rate_bps = Some(bytes_diff as f64 / time_diff);
192 }
193 last_progress_time = now;
194 last_bytes_downloaded = bytes_downloaded;
195 }
196
197 let total_time = start_time.elapsed().as_secs_f64();
198 if total_time > 0.0 {
199 performance_metrics.average_rate_bps = Some(bytes_downloaded as f64 / total_time);
200 }
201
202 self.report_progress(
203 options,
204 Progress {
205 phase: FetchPhase::Downloading,
206 bytes_downloaded,
207 total_bytes,
208 retry_count,
209 performance_metrics: Some(performance_metrics.clone()),
210 },
211 );
212 }
213
214 let downloading_duration = downloading_start.elapsed();
215 performance_metrics.phase_timings.downloading_ms = downloading_duration.as_millis() as u64;
216
217 let verifying_start = std::time::Instant::now();
218 self.report_progress(
219 options,
220 Progress {
221 phase: FetchPhase::Verifying,
222 bytes_downloaded,
223 total_bytes,
224 retry_count,
225 performance_metrics: Some(performance_metrics.clone()),
226 },
227 );
228
229 let actual_checksum = hasher.finalize();
230 if let Some(expected_checksum) = options.checksum
231 && actual_checksum != expected_checksum
232 {
233 return Err(Error::ChecksumMismatch {
234 expected: hex::encode(expected_checksum),
235 actual: hex::encode(actual_checksum),
236 });
237 }
238
239 let verifying_duration = verifying_start.elapsed();
240 performance_metrics.phase_timings.verifying_ms = verifying_duration.as_millis() as u64;
241
242 drop(file);
243
244 let committing_start = std::time::Instant::now();
245 self.report_progress(
246 options,
247 Progress {
248 phase: FetchPhase::Committing,
249 bytes_downloaded,
250 total_bytes,
251 retry_count,
252 performance_metrics: Some(performance_metrics.clone()),
253 },
254 );
255
256 workspace
257 .commit()
258 .map_err(|e| Error::Network(e.to_string()))?;
259
260 let committing_duration = committing_start.elapsed();
261 performance_metrics.phase_timings.committing_ms = committing_duration.as_millis() as u64;
262
263 self.report_progress(
264 options,
265 Progress {
266 phase: FetchPhase::Completed,
267 bytes_downloaded,
268 total_bytes,
269 retry_count,
270 performance_metrics: Some(performance_metrics),
271 },
272 );
273
274 Ok(FetchReceipt {
275 source: FetchSource::Url(url.to_string()),
276 destination: destination.to_path_buf(),
277 bytes_downloaded,
278 total_bytes,
279 sha256_hex: Some(hex::encode(actual_checksum)),
280 })
281 }
282
283 fn report_progress(&self, options: &FetchOptions, progress: Progress) {
285 if let Some(ref callback) = options.on_progress {
286 callback(&progress);
287 }
288 }
289
290 #[tracing::instrument(skip(self, source, options), fields(source = %source.url, destination = %destination.display()))]
292 pub async fn try_source(
293 &self,
294 source: &crate::DownloadSource,
295 destination: &Path,
296 options: &FetchOptions,
297 ) -> Result<FetchReceipt> {
298 let mut fetch_options = options.clone();
300 fetch_options.checksum = source.checksum;
301
302 self.fetch_with_receipt(&source.url, destination, fetch_options)
304 .await
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{FetchOptions, FetchPhase};
    use crate::net::http::BoxStream;
    use crate::progress::Progress;
    use bytes::Bytes;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    /// Error type shared by all mock clients.
    #[derive(Debug)]
    struct MockError(String);

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::error::Error for MockError {}

    /// Client that streams the single chunk `"test data"` and reports a
    /// configurable HEAD size, or fails both calls.
    struct MockHttpClient {
        should_fail: bool,
        content_length: Option<u64>,
    }

    impl MockHttpClient {
        fn new() -> Self {
            Self {
                should_fail: false,
                content_length: Some(1024),
            }
        }

        fn with_error() -> Self {
            Self {
                should_fail: true,
                content_length: None,
            }
        }

        fn without_content_length() -> Self {
            Self {
                should_fail: false,
                content_length: None,
            }
        }
    }

    impl HttpClient for MockHttpClient {
        type Error = MockError;

        async fn stream(
            &self,
            _url: &str,
            _headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            if self.should_fail {
                Err(MockError("Stream failed".to_string()))
            } else {
                let stream = futures_util::stream::once(async { Ok(Bytes::from("test data")) });
                Ok(Box::pin(stream)
                    as BoxStream<
                        'static,
                        std::result::Result<Bytes, Self::Error>,
                    >)
            }
        }

        async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            if self.should_fail {
                Err(MockError("HEAD request failed".to_string()))
            } else {
                Ok(self.content_length)
            }
        }
    }

    /// Client whose `stream` always fails. It counts HEAD and stream calls so
    /// tests can assert how many attempts and probes were made.
    #[derive(Default)]
    struct FailingStreamHttpClient {
        head_fails: bool,
        head_calls: Arc<AtomicU32>,
        stream_calls: Arc<AtomicU32>,
    }

    impl HttpClient for FailingStreamHttpClient {
        type Error = MockError;

        async fn stream(
            &self,
            _url: &str,
            _headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            Err(MockError("stream always fails".to_string()))
        }

        async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            self.head_calls.fetch_add(1, Ordering::SeqCst);
            if self.head_fails {
                Err(MockError("HEAD not supported".to_string()))
            } else {
                Ok(Some(9))
            }
        }
    }

    /// Retry policy with a 1 ms backoff so retry tests stay fast.
    fn fast_retries(max_retries: u32) -> crate::RetryPolicy {
        crate::RetryPolicy {
            max_retries,
            base_backoff: std::time::Duration::from_millis(1),
        }
    }

    #[tokio::test]
    async fn test_fetcher_new() {
        let workspace_root = "/tmp/test_workspace";
        let fetcher = Fetcher::new(MockHttpClient::new(), workspace_root);

        assert_eq!(fetcher.workspace_root, PathBuf::from(workspace_root));
    }

    #[tokio::test]
    async fn test_fetcher_head_success() {
        let fetcher = Fetcher::new(MockHttpClient::new(), "/tmp");

        let result = fetcher.head("http://example.com").await;
        assert_eq!(result.unwrap(), Some(1024));
    }

    #[tokio::test]
    async fn test_fetcher_head_without_content_length() {
        let fetcher = Fetcher::new(MockHttpClient::without_content_length(), "/tmp");

        let result = fetcher.head("http://example.com").await;
        assert_eq!(result.unwrap(), None);
    }

    #[tokio::test]
    async fn test_fetcher_head_error() {
        let fetcher = Fetcher::new(MockHttpClient::with_error(), "/tmp");

        match fetcher.head("http://example.com").await.unwrap_err() {
            Error::Network(msg) => assert!(msg.contains("HEAD request failed")),
            _ => panic!("Expected Network error"),
        }
    }

    #[tokio::test]
    async fn test_fetcher_fetch_success() {
        // Each test gets its own directory, so parallel tests cannot race on
        // a shared staging area or destination file.
        let temp = tempfile::tempdir().unwrap();
        let fetcher = Fetcher::new(MockHttpClient::new(), temp.path());
        let destination = temp.path().join("test_file");

        let receipt = fetcher
            .fetch_with_receipt("http://example.com", &destination, FetchOptions::default())
            .await
            .expect("mock download should succeed");

        assert_eq!(
            receipt.source,
            FetchSource::Url("http://example.com".to_string())
        );
        assert_eq!(receipt.destination, destination);
        assert_eq!(receipt.bytes_downloaded, "test data".len() as u64);
        assert_eq!(receipt.total_bytes, Some(1024));
        assert_eq!(receipt.sha256_hex.as_deref().map(str::len), Some(64));
        assert_eq!(std::fs::read(&destination).unwrap(), b"test data");
    }

    #[tokio::test]
    async fn test_fetcher_fetch_with_progress_callback() {
        let temp = tempfile::tempdir().unwrap();
        let fetcher = Fetcher::new(MockHttpClient::new(), temp.path());
        let destination = temp.path().join("test_file");

        let progress_called = Arc::new(AtomicBool::new(false));
        let completed_seen = Arc::new(AtomicBool::new(false));
        let called = Arc::clone(&progress_called);
        let completed = Arc::clone(&completed_seen);

        let options = FetchOptions {
            on_progress: Some(Arc::new(move |progress: &Progress| {
                called.store(true, Ordering::SeqCst);
                if matches!(progress.phase, FetchPhase::Completed) {
                    completed.store(true, Ordering::SeqCst);
                }
            })),
            ..Default::default()
        };

        fetcher
            .fetch_with_receipt("http://example.com", &destination, options)
            .await
            .expect("mock download should succeed");

        assert!(progress_called.load(Ordering::SeqCst));
        assert!(completed_seen.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_try_source() {
        let temp = tempfile::tempdir().unwrap();
        let fetcher = Fetcher::new(MockHttpClient::new(), temp.path());

        let source = crate::DownloadSource::new("http://example.com".to_string());
        let destination = temp.path().join("test_file");

        let receipt = fetcher
            .try_source(&source, &destination, &FetchOptions::default())
            .await
            .expect("mock download should succeed");

        assert_eq!(
            receipt.source,
            FetchSource::Url("http://example.com".to_string())
        );
        assert_eq!(receipt.destination, destination);
    }

    #[tokio::test]
    async fn fetch_retries_with_explicit_retry_policy() {
        let client = FailingStreamHttpClient::default();
        let stream_calls = Arc::clone(&client.stream_calls);
        let temp = tempfile::tempdir().unwrap();
        let fetcher = Fetcher::new(client, temp.path());

        let options = FetchOptions::default().retry_policy(fast_retries(1));

        let error = fetcher
            .fetch_with_receipt(
                "http://example.com",
                &temp.path().join("retry.bin"),
                options,
            )
            .await
            .unwrap_err();

        assert!(matches!(error, Error::MaxRetriesExceeded { count: 2 }));
        assert_eq!(stream_calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fetch_retries_can_use_custom_delay_provider() {
        let delay_calls = Arc::new(AtomicU32::new(0));
        let delay_calls_for_provider = Arc::clone(&delay_calls);

        let temp = tempfile::tempdir().unwrap();
        let fetcher = Fetcher::new(FailingStreamHttpClient::default(), temp.path());
        let options = FetchOptions::default()
            .retry_policy(fast_retries(2))
            .retry_delay_provider(Arc::new(move |_delay| {
                let delay_calls_for_provider = Arc::clone(&delay_calls_for_provider);
                Box::pin(async move {
                    delay_calls_for_provider.fetch_add(1, Ordering::SeqCst);
                })
            }));

        let error = fetcher
            .fetch_with_receipt(
                "http://example.com",
                &temp.path().join("retry-custom-delay.bin"),
                options,
            )
            .await
            .unwrap_err();

        assert!(matches!(error, Error::MaxRetriesExceeded { count: 3 }));
        assert_eq!(delay_calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fetch_skips_head_when_expected_bytes_is_known() {
        let client = FailingStreamHttpClient {
            head_fails: true,
            ..Default::default()
        };
        let head_calls = Arc::clone(&client.head_calls);
        let stream_calls = Arc::clone(&client.stream_calls);
        let temp = tempfile::tempdir().unwrap();
        let fetcher = Fetcher::new(client, temp.path());

        let options = FetchOptions::default()
            .retry_policy(fast_retries(0))
            .expected_bytes(Some(9));

        let error = fetcher
            .fetch_with_receipt(
                "http://example.com",
                &temp.path().join("sized.bin"),
                options,
            )
            .await
            .unwrap_err();

        // A broken HEAD must neither be called nor abort the attempt before
        // the stream is opened when the size is already known.
        assert!(matches!(error, Error::MaxRetriesExceeded { count: 1 }));
        assert_eq!(head_calls.load(Ordering::SeqCst), 0);
        assert_eq!(stream_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn fetch_applies_resume_offset_as_range_header() {
        struct HeaderCaptureHttpClient {
            seen_headers: Arc<Mutex<Vec<(String, String)>>>,
        }

        impl HttpClient for HeaderCaptureHttpClient {
            type Error = MockError;

            async fn stream(
                &self,
                _url: &str,
                headers: &[(String, String)],
            ) -> std::result::Result<
                BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
                Self::Error,
            > {
                *self.seen_headers.lock().unwrap() = headers.to_vec();
                Err(MockError("fail after header capture".to_string()))
            }

            async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
                Ok(Some(256))
            }
        }

        let seen_headers = Arc::new(Mutex::new(Vec::<(String, String)>::new()));
        let client = HeaderCaptureHttpClient {
            seen_headers: Arc::clone(&seen_headers),
        };
        let temp = tempfile::tempdir().unwrap();
        let fetcher = Fetcher::new(client, temp.path());

        let options = FetchOptions::default()
            .retry_policy(fast_retries(0))
            .resume_offset(Some(128))
            .expected_bytes(Some(256));

        let error = fetcher
            .fetch_with_receipt(
                "http://example.com",
                &temp.path().join("resume.bin"),
                options,
            )
            .await
            .unwrap_err();

        assert!(matches!(error, Error::MaxRetriesExceeded { count: 1 }));
        let headers = seen_headers.lock().unwrap().clone();
        assert!(
            headers
                .iter()
                .any(|(k, v)| k == "Range" && v == "bytes=128-")
        );
    }

    #[test]
    fn test_report_progress_without_callback() {
        let fetcher = Fetcher::new(MockHttpClient::new(), "/tmp");

        let options = FetchOptions::default();
        let progress = Progress {
            phase: FetchPhase::Connecting,
            bytes_downloaded: 0,
            total_bytes: None,
            retry_count: 0,
            performance_metrics: None,
        };

        // Must be a silent no-op rather than a panic.
        fetcher.report_progress(&options, progress);
    }

    #[test]
    fn test_report_progress_with_callback() {
        let fetcher = Fetcher::new(MockHttpClient::new(), "/tmp");

        let callback_called = Arc::new(AtomicBool::new(false));
        let callback_called_clone = Arc::clone(&callback_called);

        let options = FetchOptions {
            on_progress: Some(Arc::new(move |_progress| {
                callback_called_clone.store(true, Ordering::SeqCst);
            })),
            ..Default::default()
        };

        let progress = Progress {
            phase: FetchPhase::Connecting,
            bytes_downloaded: 0,
            total_bytes: None,
            retry_count: 0,
            performance_metrics: None,
        };

        fetcher.report_progress(&options, progress);
        assert!(callback_called.load(Ordering::SeqCst));
    }
}