1use std::path::{Path, PathBuf};
7use std::sync::Arc;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use serde::{Deserialize, Serialize};
11use tokio::fs;
12use tokio::io::AsyncWriteExt;
13
14use crate::config::{FetchOptions, FetchPhase};
15use crate::error::{Error, Result};
16use crate::fetch::fetcher::Fetcher;
17use crate::net::http::HttpClient;
18use crate::progress::Progress;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct DownloadCheckpoint {
23 pub url: String,
25 pub destination: PathBuf,
27 pub total_bytes: Option<u64>,
29 pub downloaded_bytes: u64,
31 pub partial_checksum: Option<String>,
33 pub last_update: u64,
35}
36
37impl DownloadCheckpoint {
38 pub fn new(url: String, destination: PathBuf, total_bytes: Option<u64>) -> Self {
40 Self {
41 url,
42 destination,
43 total_bytes,
44 downloaded_bytes: 0,
45 partial_checksum: None,
46 last_update: SystemTime::now()
47 .duration_since(UNIX_EPOCH)
48 .unwrap_or_default()
49 .as_secs(),
50 }
51 }
52
53 pub fn update_progress(&mut self, downloaded_bytes: u64) {
55 self.downloaded_bytes = downloaded_bytes;
56 self.last_update = SystemTime::now()
57 .duration_since(UNIX_EPOCH)
58 .unwrap_or_default()
59 .as_secs();
60 }
61
62 pub fn can_resume(&self) -> bool {
64 self.downloaded_bytes > 0
65 }
66
67 pub fn range_header(&self) -> String {
69 format!("bytes={}-", self.downloaded_bytes)
70 }
71}
72
73pub struct ResumableFetcher<C: HttpClient> {
75 base_fetcher: Fetcher<C>,
76 checkpoint_dir: PathBuf,
77}
78
79impl<C: HttpClient + 'static> ResumableFetcher<C> {
80 pub fn new(client: C, workspace_root: impl Into<PathBuf>) -> Self {
82 let workspace_root = workspace_root.into();
83 Self {
84 base_fetcher: Fetcher::new(client, workspace_root.clone()),
85 checkpoint_dir: workspace_root.join(".checkpoints"),
86 }
87 }
88
89 pub async fn fetch_resumable(
91 &self,
92 url: &str,
93 destination: &Path,
94 options: FetchOptions,
95 ) -> Result<PathBuf> {
96 fs::create_dir_all(&self.checkpoint_dir)
98 .await
99 .map_err(|e| Error::Network(e.to_string()))?;
100
101 let checkpoint_path = self.checkpoint_path(url, destination);
102
103 if let Ok(checkpoint) = self.load_checkpoint(&checkpoint_path).await
105 && checkpoint.can_resume()
106 {
107 return self
108 .resume_download(&checkpoint, &checkpoint_path, options)
109 .await;
110 }
111
112 self.start_new_download(url, destination, &checkpoint_path, options)
114 .await
115 }
116
117 async fn start_new_download(
119 &self,
120 url: &str,
121 destination: &Path,
122 checkpoint_path: &Path,
123 options: FetchOptions,
124 ) -> Result<PathBuf> {
125 let total_bytes = self
127 .base_fetcher
128 .head(url)
129 .await
130 .map_err(|e| Error::Network(e.to_string()))?;
131
132 let checkpoint =
134 DownloadCheckpoint::new(url.to_string(), destination.to_path_buf(), total_bytes);
135
136 self.save_checkpoint(&checkpoint, checkpoint_path).await?;
138
139 let checkpoint_path_clone = checkpoint_path.to_path_buf();
141 let checkpoint_dir = self.checkpoint_dir.clone();
142 let url_clone = url.to_string();
143 let destination_clone = destination.to_path_buf();
144
145 let mut options_with_checkpoint = options.clone();
146 let original_callback = options_with_checkpoint.on_progress.clone();
147
148 options_with_checkpoint.on_progress = Some(Arc::new(move |progress: &Progress| {
149 if progress.phase == FetchPhase::Downloading {
151 let mut checkpoint = DownloadCheckpoint::new(
153 url_clone.clone(),
154 destination_clone.clone(),
155 total_bytes,
156 );
157 checkpoint.update_progress(progress.bytes_downloaded);
158
159 let checkpoint_path = checkpoint_path_clone.clone();
161 let checkpoint_dir = checkpoint_dir.clone();
162 tokio::spawn(async move {
163 let _ = Self::save_checkpoint_static(
164 &checkpoint,
165 &checkpoint_path,
166 &checkpoint_dir,
167 )
168 .await;
169 });
170 }
171
172 if let Some(ref callback) = original_callback {
174 callback(progress);
175 }
176 }));
177
178 let result = self
180 .base_fetcher
181 .fetch_with_receipt(url, destination, options_with_checkpoint)
182 .await;
183
184 match result {
185 Ok(receipt) => {
186 let _ = fs::remove_file(checkpoint_path).await;
188 Ok(receipt.destination)
189 }
190 Err(e) => {
191 Err(e)
193 }
194 }
195 }
196
197 async fn resume_download(
199 &self,
200 checkpoint: &DownloadCheckpoint,
201 checkpoint_path: &Path,
202 options: FetchOptions,
203 ) -> Result<PathBuf> {
204 if checkpoint.destination.exists() {
206 let metadata = fs::metadata(&checkpoint.destination)
207 .await
208 .map_err(|e| Error::Network(e.to_string()))?;
209 let current_size = metadata.len();
210
211 if current_size != checkpoint.downloaded_bytes {
212 let _ = fs::remove_file(&checkpoint.destination).await;
214 let _ = fs::remove_file(checkpoint_path).await;
215 return self
216 .start_new_download(
217 &checkpoint.url,
218 &checkpoint.destination,
219 checkpoint_path,
220 options,
221 )
222 .await;
223 }
224 }
225
226 let mut resume_options = options.clone();
228 resume_options.resume_offset = Some(checkpoint.downloaded_bytes);
229
230 let checkpoint_path_clone = checkpoint_path.to_path_buf();
232 let checkpoint_dir = self.checkpoint_dir.clone();
233 let initial_bytes = checkpoint.downloaded_bytes;
234 let original_total_bytes = checkpoint.total_bytes;
235 let checkpoint_url = checkpoint.url.clone();
236 let checkpoint_destination = checkpoint.destination.clone();
237
238 let original_callback = resume_options.on_progress.clone();
239 resume_options.on_progress = Some(Arc::new(move |progress: &Progress| {
240 if progress.phase == FetchPhase::Downloading {
242 let total_downloaded = initial_bytes + progress.bytes_downloaded;
243
244 let mut new_checkpoint = DownloadCheckpoint::new(
246 checkpoint_url.clone(),
247 checkpoint_destination.clone(),
248 original_total_bytes,
249 );
250 new_checkpoint.update_progress(total_downloaded);
251
252 let checkpoint_path = checkpoint_path_clone.clone();
254 let checkpoint_dir = checkpoint_dir.clone();
255 tokio::spawn(async move {
256 let _ = Self::save_checkpoint_static(
257 &new_checkpoint,
258 &checkpoint_path,
259 &checkpoint_dir,
260 )
261 .await;
262 });
263 }
264
265 if let Some(ref callback) = original_callback {
267 callback(progress);
268 }
269 }));
270
271 let result = self
273 .base_fetcher
274 .fetch_with_receipt(&checkpoint.url, &checkpoint.destination, resume_options)
275 .await;
276
277 match result {
278 Ok(receipt) => {
279 let _ = fs::remove_file(checkpoint_path).await;
281 Ok(receipt.destination)
282 }
283 Err(e) => {
284 Err(e)
286 }
287 }
288 }
289
290 fn checkpoint_path(&self, url: &str, destination: &Path) -> PathBuf {
292 use std::collections::hash_map::DefaultHasher;
293 use std::hash::{Hash, Hasher};
294
295 let mut hasher = DefaultHasher::new();
297 url.hash(&mut hasher);
298 destination.hash(&mut hasher);
299 let hash = hasher.finish();
300
301 self.checkpoint_dir
302 .join(format!("checkpoint_{:016x}.json", hash))
303 }
304
305 async fn load_checkpoint(&self, path: &Path) -> Result<DownloadCheckpoint> {
307 let content = fs::read_to_string(path)
308 .await
309 .map_err(|e| Error::Network(e.to_string()))?;
310
311 serde_json::from_str(&content)
312 .map_err(|e| Error::InvalidState(format!("Invalid checkpoint: {}", e)))
313 }
314
315 async fn save_checkpoint(&self, checkpoint: &DownloadCheckpoint, path: &Path) -> Result<()> {
317 Self::save_checkpoint_static(checkpoint, path, &self.checkpoint_dir).await
318 }
319
320 async fn save_checkpoint_static(
322 checkpoint: &DownloadCheckpoint,
323 path: &Path,
324 checkpoint_dir: &Path,
325 ) -> Result<()> {
326 fs::create_dir_all(checkpoint_dir)
328 .await
329 .map_err(|e| Error::Network(e.to_string()))?;
330
331 let content = serde_json::to_string_pretty(checkpoint)
333 .map_err(|e| Error::InvalidState(format!("Failed to serialize checkpoint: {}", e)))?;
334
335 let temp_path = path.with_extension("tmp");
337 {
338 let mut file: tokio::fs::File = fs::File::create(&temp_path)
339 .await
340 .map_err(|e| Error::Network(e.to_string()))?;
341 file.write_all(content.as_bytes())
342 .await
343 .map_err(|e| Error::Network(e.to_string()))?;
344 file.sync_all()
345 .await
346 .map_err(|e| Error::Network(e.to_string()))?;
347 }
348
349 fs::rename(&temp_path, path)
351 .await
352 .map_err(|e| Error::Network(e.to_string()))?;
353
354 Ok(())
355 }
356
357 pub async fn cleanup_old_checkpoints(&self, max_age_seconds: u64) -> Result<usize> {
359 let mut cleaned = 0;
360 let cutoff = SystemTime::now()
361 .duration_since(UNIX_EPOCH)
362 .unwrap_or_default()
363 .as_secs()
364 - max_age_seconds;
365
366 let mut entries = fs::read_dir(&self.checkpoint_dir)
367 .await
368 .map_err(|e| Error::Network(e.to_string()))?;
369
370 while let Some(entry) = entries
371 .next_entry()
372 .await
373 .map_err(|e| Error::Network(e.to_string()))?
374 {
375 let path = entry.path();
376
377 if path.extension().and_then(|s| s.to_str()) == Some("json") {
378 match self.load_checkpoint(&path).await {
379 Ok(checkpoint) => {
380 if checkpoint.last_update < cutoff {
381 let _ = fs::remove_file(&path).await;
382 cleaned += 1;
383 }
384 }
385 Err(_) => {
386 let _ = fs::remove_file(&path).await;
388 cleaned += 1;
389 }
390 }
391 }
392 }
393
394 Ok(cleaned)
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::http::BoxStream;
    use bytes::Bytes;
    use tempfile::TempDir;

    /// `HttpClient` stub: every stream is empty and every HEAD reports 1024 bytes.
    #[derive(Debug)]
    struct MockClient;

    impl MockClient {
        fn new() -> Self {
            Self
        }
    }

    #[derive(Debug)]
    struct MockError(String);

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::error::Error for MockError {}

    impl HttpClient for MockClient {
        type Error = MockError;

        async fn stream(
            &self,
            _url: &str,
            _headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            let empty: BoxStream<'static, std::result::Result<Bytes, Self::Error>> =
                Box::pin(futures_util::stream::empty());
            Ok(empty)
        }

        async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            Ok(Some(1024))
        }
    }

    /// Current Unix time in seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Checkpoint for `https://example.com/<name>` saved to `/tmp/<name>`.
    fn sample_checkpoint(name: &str) -> DownloadCheckpoint {
        DownloadCheckpoint::new(
            format!("https://example.com/{name}"),
            PathBuf::from(format!("/tmp/{name}")),
            Some(1024),
        )
    }

    #[test]
    fn test_download_checkpoint() {
        let mut checkpoint = sample_checkpoint("file.txt");

        assert_eq!(checkpoint.downloaded_bytes, 0);
        assert!(!checkpoint.can_resume());
        assert_eq!(checkpoint.range_header(), "bytes=0-");

        checkpoint.update_progress(512);
        assert_eq!(checkpoint.downloaded_bytes, 512);
        assert!(checkpoint.can_resume());
        assert_eq!(checkpoint.range_header(), "bytes=512-");
    }

    #[test]
    fn test_checkpoint_path_is_deterministic_and_distinct() {
        let temp_dir = TempDir::new().unwrap();
        let fetcher = ResumableFetcher::new(MockClient::new(), temp_dir.path());

        let first = fetcher.checkpoint_path("https://example.com/a", Path::new("/tmp/a"));
        let again = fetcher.checkpoint_path("https://example.com/a", Path::new("/tmp/a"));
        let other_url = fetcher.checkpoint_path("https://example.com/b", Path::new("/tmp/a"));
        let other_dest = fetcher.checkpoint_path("https://example.com/a", Path::new("/tmp/b"));

        assert_eq!(first, again);
        assert_ne!(first, other_url);
        assert_ne!(first, other_dest);
        assert!(first.starts_with(temp_dir.path().join(".checkpoints")));
        // Cleanup only considers `.json` files.
        assert_eq!(first.extension().and_then(|ext| ext.to_str()), Some("json"));
    }

    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let temp_dir = TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.json");

        let mut original = sample_checkpoint("file.txt");
        original.update_progress(512);

        let fetcher = ResumableFetcher::new(MockClient::new(), temp_dir.path());
        fetcher
            .save_checkpoint(&original, &checkpoint_path)
            .await
            .unwrap();

        let loaded = fetcher.load_checkpoint(&checkpoint_path).await.unwrap();

        assert_eq!(loaded.url, original.url);
        assert_eq!(loaded.destination, original.destination);
        assert_eq!(loaded.total_bytes, original.total_bytes);
        assert_eq!(loaded.downloaded_bytes, original.downloaded_bytes);
        assert_eq!(loaded.partial_checksum, original.partial_checksum);
        assert_eq!(loaded.last_update, original.last_update);
        // The temporary file must have been renamed into place.
        assert!(!checkpoint_path.with_extension("tmp").exists());
    }

    #[tokio::test]
    async fn test_cleanup_old_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let fetcher = ResumableFetcher::new(MockClient::new(), temp_dir.path());

        let old_timestamp = unix_now() - 10;
        let mut stale_paths = Vec::new();
        for name in ["file1.txt", "file2.txt"] {
            let mut checkpoint = sample_checkpoint(name);
            checkpoint.last_update = old_timestamp;
            let path = fetcher.checkpoint_path(&checkpoint.url, &checkpoint.destination);
            fetcher.save_checkpoint(&checkpoint, &path).await.unwrap();
            stale_paths.push(path);
        }

        let fresh = sample_checkpoint("file3.txt");
        let fresh_path = fetcher.checkpoint_path(&fresh.url, &fresh.destination);
        fetcher.save_checkpoint(&fresh, &fresh_path).await.unwrap();

        let corrupt_path = fetcher.checkpoint_dir.join("checkpoint_corrupt.json");
        fs::write(&corrupt_path, "not json").await.unwrap();

        let cleaned = fetcher.cleanup_old_checkpoints(5).await.unwrap();

        // Two stale checkpoints plus the unparseable one.
        assert_eq!(cleaned, 3);
        for path in &stale_paths {
            assert!(!path.exists(), "stale checkpoint {path:?} should be removed");
        }
        assert!(!corrupt_path.exists(), "corrupt checkpoint should be removed");
        assert!(fresh_path.exists(), "a recently updated checkpoint must be kept");
    }
}