Skip to main content

pulith_fetch/fetch/
resumable.rs

1//! Resumable download functionality.
2//!
3//! This module provides the ability to resume interrupted downloads
4//! using HTTP Range requests and state persistence.
5
6use std::path::{Path, PathBuf};
7use std::sync::Arc;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use serde::{Deserialize, Serialize};
11use tokio::fs;
12use tokio::io::AsyncWriteExt;
13
14use crate::config::{FetchOptions, FetchPhase};
15use crate::error::{Error, Result};
16use crate::fetch::fetcher::Fetcher;
17use crate::net::http::HttpClient;
18use crate::progress::Progress;
19
20/// Checkpoint data for resumable downloads.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct DownloadCheckpoint {
23    /// URL being downloaded
24    pub url: String,
25    /// Destination path
26    pub destination: PathBuf,
27    /// Total bytes expected (from Content-Length)
28    pub total_bytes: Option<u64>,
29    /// Bytes already downloaded
30    pub downloaded_bytes: u64,
31    /// Checksum of downloaded data (if available)
32    pub partial_checksum: Option<String>,
33    /// Timestamp of last progress
34    pub last_update: u64,
35}
36
37impl DownloadCheckpoint {
38    /// Create a new checkpoint.
39    pub fn new(url: String, destination: PathBuf, total_bytes: Option<u64>) -> Self {
40        Self {
41            url,
42            destination,
43            total_bytes,
44            downloaded_bytes: 0,
45            partial_checksum: None,
46            last_update: SystemTime::now()
47                .duration_since(UNIX_EPOCH)
48                .unwrap_or_default()
49                .as_secs(),
50        }
51    }
52
53    /// Update checkpoint with new progress.
54    pub fn update_progress(&mut self, downloaded_bytes: u64) {
55        self.downloaded_bytes = downloaded_bytes;
56        self.last_update = SystemTime::now()
57            .duration_since(UNIX_EPOCH)
58            .unwrap_or_default()
59            .as_secs();
60    }
61
62    /// Check if download can be resumed.
63    pub fn can_resume(&self) -> bool {
64        self.downloaded_bytes > 0
65    }
66
67    /// Get the Range header value for resuming.
68    pub fn range_header(&self) -> String {
69        format!("bytes={}-", self.downloaded_bytes)
70    }
71}
72
73/// Resumable fetcher implementation.
74pub struct ResumableFetcher<C: HttpClient> {
75    base_fetcher: Fetcher<C>,
76    checkpoint_dir: PathBuf,
77}
78
79impl<C: HttpClient + 'static> ResumableFetcher<C> {
80    /// Create a new resumable fetcher.
81    pub fn new(client: C, workspace_root: impl Into<PathBuf>) -> Self {
82        let workspace_root = workspace_root.into();
83        Self {
84            base_fetcher: Fetcher::new(client, workspace_root.clone()),
85            checkpoint_dir: workspace_root.join(".checkpoints"),
86        }
87    }
88
89    /// Fetch a file with resumable support.
90    pub async fn fetch_resumable(
91        &self,
92        url: &str,
93        destination: &Path,
94        options: FetchOptions,
95    ) -> Result<PathBuf> {
96        // Ensure checkpoint directory exists
97        fs::create_dir_all(&self.checkpoint_dir)
98            .await
99            .map_err(|e| Error::Network(e.to_string()))?;
100
101        let checkpoint_path = self.checkpoint_path(url, destination);
102
103        // Try to load existing checkpoint
104        if let Ok(checkpoint) = self.load_checkpoint(&checkpoint_path).await
105            && checkpoint.can_resume()
106        {
107            return self
108                .resume_download(&checkpoint, &checkpoint_path, options)
109                .await;
110        }
111
112        // Start new download
113        self.start_new_download(url, destination, &checkpoint_path, options)
114            .await
115    }
116
117    /// Start a new download with checkpoint tracking.
118    async fn start_new_download(
119        &self,
120        url: &str,
121        destination: &Path,
122        checkpoint_path: &Path,
123        options: FetchOptions,
124    ) -> Result<PathBuf> {
125        // Get total bytes from HEAD request
126        let total_bytes = self
127            .base_fetcher
128            .head(url)
129            .await
130            .map_err(|e| Error::Network(e.to_string()))?;
131
132        // Create initial checkpoint
133        let checkpoint =
134            DownloadCheckpoint::new(url.to_string(), destination.to_path_buf(), total_bytes);
135
136        // Save initial checkpoint
137        self.save_checkpoint(&checkpoint, checkpoint_path).await?;
138
139        // Set up progress callback to update checkpoint
140        let checkpoint_path_clone = checkpoint_path.to_path_buf();
141        let checkpoint_dir = self.checkpoint_dir.clone();
142        let url_clone = url.to_string();
143        let destination_clone = destination.to_path_buf();
144
145        let mut options_with_checkpoint = options.clone();
146        let original_callback = options_with_checkpoint.on_progress.clone();
147
148        options_with_checkpoint.on_progress = Some(Arc::new(move |progress: &Progress| {
149            // Update checkpoint on progress
150            if progress.phase == FetchPhase::Downloading {
151                // Create new checkpoint with updated progress
152                let mut checkpoint = DownloadCheckpoint::new(
153                    url_clone.clone(),
154                    destination_clone.clone(),
155                    total_bytes,
156                );
157                checkpoint.update_progress(progress.bytes_downloaded);
158
159                // Save checkpoint asynchronously (fire and forget)
160                let checkpoint_path = checkpoint_path_clone.clone();
161                let checkpoint_dir = checkpoint_dir.clone();
162                tokio::spawn(async move {
163                    let _ = Self::save_checkpoint_static(
164                        &checkpoint,
165                        &checkpoint_path,
166                        &checkpoint_dir,
167                    )
168                    .await;
169                });
170            }
171
172            // Call original callback if present
173            if let Some(ref callback) = original_callback {
174                callback(progress);
175            }
176        }));
177
178        // Perform the download
179        let result = self
180            .base_fetcher
181            .fetch_with_receipt(url, destination, options_with_checkpoint)
182            .await;
183
184        match result {
185            Ok(receipt) => {
186                // Download successful, remove checkpoint
187                let _ = fs::remove_file(checkpoint_path).await;
188                Ok(receipt.destination)
189            }
190            Err(e) => {
191                // Download failed, checkpoint remains for resuming
192                Err(e)
193            }
194        }
195    }
196
197    /// Resume an interrupted download.
198    async fn resume_download(
199        &self,
200        checkpoint: &DownloadCheckpoint,
201        checkpoint_path: &Path,
202        options: FetchOptions,
203    ) -> Result<PathBuf> {
204        // Check if destination file exists and has expected size
205        if checkpoint.destination.exists() {
206            let metadata = fs::metadata(&checkpoint.destination)
207                .await
208                .map_err(|e| Error::Network(e.to_string()))?;
209            let current_size = metadata.len();
210
211            if current_size != checkpoint.downloaded_bytes {
212                // File size mismatch, start over
213                let _ = fs::remove_file(&checkpoint.destination).await;
214                let _ = fs::remove_file(checkpoint_path).await;
215                return self
216                    .start_new_download(
217                        &checkpoint.url,
218                        &checkpoint.destination,
219                        checkpoint_path,
220                        options,
221                    )
222                    .await;
223            }
224        }
225
226        // Create fetch options with explicit resume offset.
227        let mut resume_options = options.clone();
228        resume_options.resume_offset = Some(checkpoint.downloaded_bytes);
229
230        // Set up progress callback for resumed download
231        let checkpoint_path_clone = checkpoint_path.to_path_buf();
232        let checkpoint_dir = self.checkpoint_dir.clone();
233        let initial_bytes = checkpoint.downloaded_bytes;
234        let original_total_bytes = checkpoint.total_bytes;
235        let checkpoint_url = checkpoint.url.clone();
236        let checkpoint_destination = checkpoint.destination.clone();
237
238        let original_callback = resume_options.on_progress.clone();
239        resume_options.on_progress = Some(Arc::new(move |progress: &Progress| {
240            // Update checkpoint with total progress (initial + new)
241            if progress.phase == FetchPhase::Downloading {
242                let total_downloaded = initial_bytes + progress.bytes_downloaded;
243
244                // Create new checkpoint with updated progress
245                let mut new_checkpoint = DownloadCheckpoint::new(
246                    checkpoint_url.clone(),
247                    checkpoint_destination.clone(),
248                    original_total_bytes,
249                );
250                new_checkpoint.update_progress(total_downloaded);
251
252                // Save checkpoint asynchronously
253                let checkpoint_path = checkpoint_path_clone.clone();
254                let checkpoint_dir = checkpoint_dir.clone();
255                tokio::spawn(async move {
256                    let _ = Self::save_checkpoint_static(
257                        &new_checkpoint,
258                        &checkpoint_path,
259                        &checkpoint_dir,
260                    )
261                    .await;
262                });
263            }
264
265            // Call original callback if present
266            if let Some(ref callback) = original_callback {
267                callback(progress);
268            }
269        }));
270
271        // Resume the download
272        let result = self
273            .base_fetcher
274            .fetch_with_receipt(&checkpoint.url, &checkpoint.destination, resume_options)
275            .await;
276
277        match result {
278            Ok(receipt) => {
279                // Download successful, remove checkpoint
280                let _ = fs::remove_file(checkpoint_path).await;
281                Ok(receipt.destination)
282            }
283            Err(e) => {
284                // Download failed, checkpoint remains
285                Err(e)
286            }
287        }
288    }
289
290    /// Get the checkpoint file path for a download.
291    fn checkpoint_path(&self, url: &str, destination: &Path) -> PathBuf {
292        use std::collections::hash_map::DefaultHasher;
293        use std::hash::{Hash, Hasher};
294
295        // Create a unique filename from URL and destination
296        let mut hasher = DefaultHasher::new();
297        url.hash(&mut hasher);
298        destination.hash(&mut hasher);
299        let hash = hasher.finish();
300
301        self.checkpoint_dir
302            .join(format!("checkpoint_{:016x}.json", hash))
303    }
304
305    /// Load a checkpoint from file.
306    async fn load_checkpoint(&self, path: &Path) -> Result<DownloadCheckpoint> {
307        let content = fs::read_to_string(path)
308            .await
309            .map_err(|e| Error::Network(e.to_string()))?;
310
311        serde_json::from_str(&content)
312            .map_err(|e| Error::InvalidState(format!("Invalid checkpoint: {}", e)))
313    }
314
315    /// Save a checkpoint to file.
316    async fn save_checkpoint(&self, checkpoint: &DownloadCheckpoint, path: &Path) -> Result<()> {
317        Self::save_checkpoint_static(checkpoint, path, &self.checkpoint_dir).await
318    }
319
320    /// Static version of save_checkpoint for use in closures.
321    async fn save_checkpoint_static(
322        checkpoint: &DownloadCheckpoint,
323        path: &Path,
324        checkpoint_dir: &Path,
325    ) -> Result<()> {
326        // Ensure directory exists
327        fs::create_dir_all(checkpoint_dir)
328            .await
329            .map_err(|e| Error::Network(e.to_string()))?;
330
331        // Serialize checkpoint
332        let content = serde_json::to_string_pretty(checkpoint)
333            .map_err(|e| Error::InvalidState(format!("Failed to serialize checkpoint: {}", e)))?;
334
335        // Write to temporary file first
336        let temp_path = path.with_extension("tmp");
337        {
338            let mut file: tokio::fs::File = fs::File::create(&temp_path)
339                .await
340                .map_err(|e| Error::Network(e.to_string()))?;
341            file.write_all(content.as_bytes())
342                .await
343                .map_err(|e| Error::Network(e.to_string()))?;
344            file.sync_all()
345                .await
346                .map_err(|e| Error::Network(e.to_string()))?;
347        }
348
349        // Atomic rename
350        fs::rename(&temp_path, path)
351            .await
352            .map_err(|e| Error::Network(e.to_string()))?;
353
354        Ok(())
355    }
356
357    /// Clean up old checkpoints.
358    pub async fn cleanup_old_checkpoints(&self, max_age_seconds: u64) -> Result<usize> {
359        let mut cleaned = 0;
360        let cutoff = SystemTime::now()
361            .duration_since(UNIX_EPOCH)
362            .unwrap_or_default()
363            .as_secs()
364            - max_age_seconds;
365
366        let mut entries = fs::read_dir(&self.checkpoint_dir)
367            .await
368            .map_err(|e| Error::Network(e.to_string()))?;
369
370        while let Some(entry) = entries
371            .next_entry()
372            .await
373            .map_err(|e| Error::Network(e.to_string()))?
374        {
375            let path = entry.path();
376
377            if path.extension().and_then(|s| s.to_str()) == Some("json") {
378                match self.load_checkpoint(&path).await {
379                    Ok(checkpoint) => {
380                        if checkpoint.last_update < cutoff {
381                            let _ = fs::remove_file(&path).await;
382                            cleaned += 1;
383                        }
384                    }
385                    Err(_) => {
386                        // Invalid checkpoint, remove it
387                        let _ = fs::remove_file(&path).await;
388                        cleaned += 1;
389                    }
390                }
391            }
392        }
393
394        Ok(cleaned)
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::http::BoxStream;
    use bytes::Bytes;
    use tempfile::TempDir;

    /// Minimal HTTP client: streams are always empty and HEAD always reports 1024 bytes.
    #[derive(Debug)]
    struct MockClient;

    impl MockClient {
        fn new() -> Self {
            Self
        }
    }

    /// Error type required by [`HttpClient`]; never produced by [`MockClient`].
    #[derive(Debug)]
    struct MockError(String);

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl std::error::Error for MockError {}

    impl HttpClient for MockClient {
        type Error = MockError;

        async fn stream(
            &self,
            _url: &str,
            _headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            let empty: BoxStream<'static, std::result::Result<Bytes, Self::Error>> =
                Box::pin(futures_util::stream::empty());
            Ok(empty)
        }

        async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            Ok(Some(1024))
        }
    }

    /// Builds a fetcher whose workspace is `temp_dir`.
    fn new_fetcher(temp_dir: &TempDir) -> ResumableFetcher<MockClient> {
        ResumableFetcher::new(MockClient::new(), temp_dir.path())
    }

    /// Builds a checkpoint for `url` -> `destination` with a known total of 1024 bytes.
    fn sample_checkpoint(url: &str, destination: &str) -> DownloadCheckpoint {
        DownloadCheckpoint::new(url.to_string(), PathBuf::from(destination), Some(1024))
    }

    /// Current time in whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    #[test]
    fn test_download_checkpoint() {
        let mut checkpoint = sample_checkpoint("https://example.com/file.txt", "/tmp/file.txt");

        assert_eq!(checkpoint.downloaded_bytes, 0);
        assert!(!checkpoint.can_resume());
        assert_eq!(checkpoint.range_header(), "bytes=0-");

        checkpoint.update_progress(512);
        assert_eq!(checkpoint.downloaded_bytes, 512);
        assert!(checkpoint.can_resume());
        assert_eq!(checkpoint.range_header(), "bytes=512-");
    }

    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let temp_dir = TempDir::new().unwrap();
        let fetcher = new_fetcher(&temp_dir);
        let checkpoint_path = temp_dir.path().join("checkpoint.json");

        let mut original = sample_checkpoint("https://example.com/file.txt", "/tmp/file.txt");
        original.update_progress(512);

        fetcher
            .save_checkpoint(&original, &checkpoint_path)
            .await
            .unwrap();
        let loaded = fetcher.load_checkpoint(&checkpoint_path).await.unwrap();

        assert_eq!(loaded.url, original.url);
        assert_eq!(loaded.destination, original.destination);
        assert_eq!(loaded.total_bytes, original.total_bytes);
        assert_eq!(loaded.downloaded_bytes, original.downloaded_bytes);
        assert_eq!(loaded.last_update, original.last_update);
    }

    #[tokio::test]
    async fn test_cleanup_old_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let fetcher = new_fetcher(&temp_dir);
        let ten_seconds_ago = unix_now() - 10;

        let mut paths = Vec::new();
        for name in ["file1.txt", "file2.txt"] {
            let url = format!("https://example.com/{name}");
            let destination = format!("/tmp/{name}");
            let mut checkpoint = sample_checkpoint(&url, &destination);
            checkpoint.last_update = ten_seconds_ago;

            let path = fetcher.checkpoint_path(&url, Path::new(&destination));
            fetcher.save_checkpoint(&checkpoint, &path).await.unwrap();
            paths.push(path);
        }

        // A max age of 5 seconds must remove both 10-second-old checkpoints.
        let cleaned = fetcher.cleanup_old_checkpoints(5).await.unwrap();
        assert_eq!(cleaned, 2);
        assert!(paths.iter().all(|path| !path.exists()));
    }

    #[tokio::test]
    async fn test_cleanup_keeps_fresh_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let fetcher = new_fetcher(&temp_dir);

        let url = "https://example.com/fresh.txt";
        let checkpoint = sample_checkpoint(url, "/tmp/fresh.txt");
        let path = fetcher.checkpoint_path(url, Path::new("/tmp/fresh.txt"));
        fetcher.save_checkpoint(&checkpoint, &path).await.unwrap();

        let cleaned = fetcher.cleanup_old_checkpoints(3600).await.unwrap();
        assert_eq!(cleaned, 0);
        assert!(path.exists());
    }

    #[tokio::test]
    async fn test_cleanup_removes_invalid_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let fetcher = new_fetcher(&temp_dir);
        fs::create_dir_all(&fetcher.checkpoint_dir).await.unwrap();

        let broken = fetcher.checkpoint_dir.join("broken.json");
        let unrelated = fetcher.checkpoint_dir.join("notes.txt");
        fs::write(&broken, "not a checkpoint").await.unwrap();
        fs::write(&unrelated, "keep me").await.unwrap();

        let cleaned = fetcher.cleanup_old_checkpoints(3600).await.unwrap();
        assert_eq!(cleaned, 1);
        assert!(!broken.exists());
        assert!(unrelated.exists());
    }

    #[test]
    fn test_checkpoint_path_is_deterministic_and_distinct() {
        let temp_dir = TempDir::new().unwrap();
        let fetcher = new_fetcher(&temp_dir);
        let destination = Path::new("/tmp/file.txt");

        let first = fetcher.checkpoint_path("https://example.com/a", destination);
        let again = fetcher.checkpoint_path("https://example.com/a", destination);
        let other_url = fetcher.checkpoint_path("https://example.com/b", destination);
        let other_dest =
            fetcher.checkpoint_path("https://example.com/a", Path::new("/tmp/other.txt"));

        assert_eq!(first, again);
        assert_ne!(first, other_url);
        assert_ne!(first, other_dest);
        assert_eq!(first.parent(), Some(fetcher.checkpoint_dir.as_path()));
        assert_eq!(first.extension().and_then(|s| s.to_str()), Some("json"));
    }
}