tact_client/
resumable.rs

1//! Resumable download functionality for TACT clients
2//!
3//! This module provides support for downloading files that can be interrupted and resumed
4//! from the last successfully downloaded byte. It persists download state to disk and
5//! uses HTTP range requests to continue interrupted downloads.
6
7use crate::{Error, HttpClient, Result};
8use reqwest::Response;
9use serde::{Deserialize, Serialize};
10use std::path::{Path, PathBuf};
11use tokio::fs::{File, OpenOptions};
12use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
13use tracing::{debug, info, warn};
14
15/// Download progress information persisted to disk
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct DownloadProgress {
18    /// Total expected file size in bytes
19    pub total_size: Option<u64>,
20    /// Number of bytes downloaded so far
21    pub bytes_downloaded: u64,
22    /// Original file hash for verification
23    pub file_hash: String,
24    /// CDN host used for download
25    pub cdn_host: String,
26    /// CDN path for the file
27    pub cdn_path: String,
28    /// Target file path where content is being written
29    pub target_file: PathBuf,
30    /// Progress file path for state persistence
31    pub progress_file: PathBuf,
32    /// Whether the download is complete
33    pub is_complete: bool,
34    /// Timestamp of last update (for cleanup of old progress files)
35    pub last_updated: u64,
36}
37
38/// Resumable download manager
39#[derive(Debug)]
40pub struct ResumableDownload {
41    client: HttpClient,
42    progress: DownloadProgress,
43}
44
45impl DownloadProgress {
46    /// Create a new download progress tracker
47    pub fn new(
48        file_hash: String,
49        cdn_host: String,
50        cdn_path: String,
51        target_file: PathBuf,
52    ) -> Self {
53        let progress_file = target_file.with_extension("download");
54
55        Self {
56            total_size: None,
57            bytes_downloaded: 0,
58            file_hash,
59            cdn_host,
60            cdn_path,
61            target_file,
62            progress_file,
63            is_complete: false,
64            last_updated: current_timestamp(),
65        }
66    }
67
68    /// Load progress from disk
69    pub async fn load_from_file(progress_file: &Path) -> Result<Self> {
70        let content = tokio::fs::read_to_string(progress_file).await?;
71        let mut progress: DownloadProgress = serde_json::from_str(&content)?;
72        progress.last_updated = current_timestamp();
73        Ok(progress)
74    }
75
76    /// Save progress to disk
77    pub async fn save_to_file(&self) -> Result<()> {
78        let content = serde_json::to_string_pretty(self)?;
79        tokio::fs::write(&self.progress_file, content).await?;
80        debug!("Saved download progress to {:?}", self.progress_file);
81        Ok(())
82    }
83
84    /// Check if the target file exists and has the expected size
85    pub async fn verify_existing_file(&self) -> Result<bool> {
86        if let Ok(metadata) = tokio::fs::metadata(&self.target_file).await {
87            let file_size = metadata.len();
88
89            // If we know the total size, check if it matches
90            if let Some(total) = self.total_size {
91                return Ok(file_size == total);
92            }
93
94            // If file exists and we've downloaded some bytes, assume it's valid for resume
95            Ok(file_size >= self.bytes_downloaded)
96        } else {
97            Ok(false)
98        }
99    }
100
101    /// Calculate download completion percentage
102    pub fn completion_percentage(&self) -> Option<f64> {
103        self.total_size.map(|total| {
104            if total == 0 {
105                100.0
106            } else {
107                (self.bytes_downloaded as f64 / total as f64) * 100.0
108            }
109        })
110    }
111
112    /// Get human-readable progress string
113    pub fn progress_string(&self) -> String {
114        match (self.total_size, self.completion_percentage()) {
115            (Some(total), Some(percent)) => {
116                format!(
117                    "{}/{} bytes ({:.1}%)",
118                    format_bytes(self.bytes_downloaded),
119                    format_bytes(total),
120                    percent
121                )
122            }
123            (Some(total), None) => {
124                format!(
125                    "{}/{} bytes",
126                    format_bytes(self.bytes_downloaded),
127                    format_bytes(total)
128                )
129            }
130            (None, _) => {
131                format!("{} bytes", format_bytes(self.bytes_downloaded))
132            }
133        }
134    }
135}
136
137impl ResumableDownload {
138    /// Create a new resumable download
139    pub fn new(client: HttpClient, progress: DownloadProgress) -> Self {
140        Self { client, progress }
141    }
142
143    /// Start or resume a download
144    pub async fn start_or_resume(&mut self) -> Result<()> {
145        // Check if we can resume from existing file
146        let can_resume = if self.progress.bytes_downloaded > 0 {
147            self.progress.verify_existing_file().await.unwrap_or(false)
148        } else {
149            false
150        };
151
152        if can_resume {
153            info!(
154                "Resuming download from {} bytes for {}",
155                self.progress.bytes_downloaded, self.progress.file_hash
156            );
157        } else {
158            info!("Starting new download for {}", self.progress.file_hash);
159            self.progress.bytes_downloaded = 0;
160        }
161
162        // Save initial progress
163        self.progress.save_to_file().await?;
164
165        // Start the download
166        self.download_with_resume().await
167    }
168
169    /// Perform the actual download with resume capability
170    async fn download_with_resume(&mut self) -> Result<()> {
171        // Open or create the target file
172        let mut file = OpenOptions::new()
173            .create(true)
174            .write(true)
175            .read(true)
176            .truncate(false)
177            .open(&self.progress.target_file)
178            .await?;
179
180        // Seek to the resume position
181        if self.progress.bytes_downloaded > 0 {
182            file.seek(SeekFrom::Start(self.progress.bytes_downloaded))
183                .await?;
184        }
185
186        // Make range request from resume position
187        let range = (self.progress.bytes_downloaded, None);
188        let response = self
189            .client
190            .download_file_range(
191                &self.progress.cdn_host,
192                &self.progress.cdn_path,
193                &self.progress.file_hash,
194                range,
195            )
196            .await?;
197
198        // Extract total size from headers if available
199        if self.progress.total_size.is_none() {
200            self.progress.total_size =
201                extract_total_size(&response, self.progress.bytes_downloaded);
202        }
203
204        // Check response status
205        match response.status() {
206            reqwest::StatusCode::PARTIAL_CONTENT => {
207                debug!(
208                    "Server supports range requests, resuming from byte {}",
209                    self.progress.bytes_downloaded
210                );
211            }
212            reqwest::StatusCode::OK => {
213                if self.progress.bytes_downloaded > 0 {
214                    warn!(
215                        "Server doesn't support range requests, restarting download from beginning"
216                    );
217                    file.seek(SeekFrom::Start(0)).await?;
218                    file.set_len(0).await?;
219                    self.progress.bytes_downloaded = 0;
220                }
221            }
222            _status => {
223                return Err(Error::InvalidResponse);
224            }
225        }
226
227        // Stream the response to the file with progress tracking
228        self.stream_response_to_file(response, &mut file).await?;
229
230        // Mark as complete and clean up
231        self.progress.is_complete = true;
232        self.progress.save_to_file().await?;
233
234        info!("Download completed: {}", self.progress.progress_string());
235        Ok(())
236    }
237
238    /// Stream response content to file with progress updates
239    async fn stream_response_to_file(&mut self, response: Response, file: &mut File) -> Result<()> {
240        let mut stream = response.bytes_stream();
241        let mut bytes_written_since_save = 0u64;
242        const SAVE_INTERVAL: u64 = 1024 * 1024; // Save progress every 1MB
243
244        use futures_util::StreamExt;
245
246        while let Some(chunk_result) = stream.next().await {
247            let chunk = chunk_result.map_err(Error::Http)?;
248
249            // Write chunk to file
250            file.write_all(&chunk).await?;
251
252            // Update progress
253            let chunk_size = chunk.len() as u64;
254            self.progress.bytes_downloaded += chunk_size;
255            bytes_written_since_save += chunk_size;
256
257            // Periodically save progress to disk
258            if bytes_written_since_save >= SAVE_INTERVAL {
259                file.flush().await?;
260                self.progress.last_updated = current_timestamp();
261                self.progress.save_to_file().await?;
262                bytes_written_since_save = 0;
263
264                debug!("Progress: {}", self.progress.progress_string());
265            }
266        }
267
268        // Final flush and progress save
269        file.flush().await?;
270        self.progress.last_updated = current_timestamp();
271
272        Ok(())
273    }
274
275    /// Get current progress
276    pub fn progress(&self) -> &DownloadProgress {
277        &self.progress
278    }
279
280    /// Cancel the download and clean up progress file
281    pub async fn cancel(&self) -> Result<()> {
282        if self.progress.progress_file.exists() {
283            tokio::fs::remove_file(&self.progress.progress_file).await?;
284            debug!("Removed progress file {:?}", self.progress.progress_file);
285        }
286        Ok(())
287    }
288
289    /// Clean up completed download (remove progress file, keep target file)
290    pub async fn cleanup_completed(&self) -> Result<()> {
291        if self.progress.is_complete && self.progress.progress_file.exists() {
292            tokio::fs::remove_file(&self.progress.progress_file).await?;
293            debug!("Cleaned up progress file for completed download");
294        }
295        Ok(())
296    }
297}
298
299/// Extract total file size from HTTP response headers
300fn extract_total_size(response: &Response, bytes_already_downloaded: u64) -> Option<u64> {
301    // Try Content-Range header first (for partial content)
302    if let Some(content_range) = response.headers().get("content-range") {
303        if let Ok(range_str) = content_range.to_str() {
304            // Format: "bytes 200-1023/1024"
305            if let Some(total_str) = range_str.split('/').nth(1) {
306                if let Ok(total) = total_str.parse::<u64>() {
307                    return Some(total);
308                }
309            }
310        }
311    }
312
313    // Fall back to Content-Length header
314    if let Some(content_length) = response.headers().get("content-length") {
315        if let Ok(length_str) = content_length.to_str() {
316            if let Ok(length) = length_str.parse::<u64>() {
317                // If this is a partial response, add the bytes we already have
318                return Some(length + bytes_already_downloaded);
319            }
320        }
321    }
322
323    None
324}
325
/// Render a byte count for display, scaling by 1024 up to terabytes.
///
/// Counts below 1024 are printed as an exact integer with the `B` unit; larger
/// values are printed with two decimals in the largest unit that fits.
fn format_bytes(bytes: u64) -> String {
    const STEP: f64 = 1024.0;
    const LABELS: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];

    // Small values never leave the base unit and keep their integer form
    if bytes < 1024 {
        return format!("{} {}", bytes, LABELS[0]);
    }

    let mut scaled = bytes as f64;
    let mut label = LABELS[0];
    for next in &LABELS[1..] {
        if scaled < STEP {
            break;
        }
        scaled /= STEP;
        label = next;
    }

    format!("{:.2} {}", scaled, label)
}
343
/// Seconds elapsed since the Unix epoch according to the system clock.
///
/// Yields 0 if the clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
351
352/// Find all resumable downloads in a directory
353pub async fn find_resumable_downloads(dir: &Path) -> Result<Vec<DownloadProgress>> {
354    let mut downloads = Vec::new();
355
356    if !dir.exists() {
357        return Ok(downloads);
358    }
359
360    let mut entries = tokio::fs::read_dir(dir).await?;
361
362    while let Some(entry) = entries.next_entry().await? {
363        let path = entry.path();
364
365        if path.extension().and_then(|s| s.to_str()) == Some("download") {
366            match DownloadProgress::load_from_file(&path).await {
367                Ok(progress) => {
368                    if !progress.is_complete {
369                        downloads.push(progress);
370                    }
371                }
372                Err(e) => {
373                    warn!("Failed to load download progress from {:?}: {}", path, e);
374                }
375            }
376        }
377    }
378
379    Ok(downloads)
380}
381
382/// Clean up old completed download progress files
383pub async fn cleanup_old_progress_files(dir: &Path, max_age_hours: u64) -> Result<usize> {
384    let max_age_secs = max_age_hours * 3600;
385    let current_time = current_timestamp();
386    let mut cleaned_count = 0;
387
388    if !dir.exists() {
389        return Ok(0);
390    }
391
392    let mut entries = tokio::fs::read_dir(dir).await?;
393
394    while let Some(entry) = entries.next_entry().await? {
395        let path = entry.path();
396
397        if path.extension().and_then(|s| s.to_str()) == Some("download") {
398            match DownloadProgress::load_from_file(&path).await {
399                Ok(progress) => {
400                    let age = current_time.saturating_sub(progress.last_updated);
401
402                    if progress.is_complete
403                        && age > max_age_secs
404                        && tokio::fs::remove_file(&path).await.is_ok()
405                    {
406                        cleaned_count += 1;
407                        debug!("Cleaned up old progress file: {:?}", path);
408                    }
409                }
410                Err(_) => {
411                    // If we can't parse the progress file, it might be corrupted
412                    // Clean it up if it's old enough based on file modification time
413                    if let Ok(metadata) = tokio::fs::metadata(&path).await {
414                        if let Ok(modified) = metadata.modified() {
415                            let file_age = std::time::SystemTime::now()
416                                .duration_since(modified)
417                                .unwrap_or_default()
418                                .as_secs();
419
420                            if file_age > max_age_secs
421                                && tokio::fs::remove_file(&path).await.is_ok()
422                            {
423                                cleaned_count += 1;
424                                debug!("Cleaned up corrupted progress file: {:?}", path);
425                            }
426                        }
427                    }
428                }
429            }
430        }
431    }
432
433    Ok(cleaned_count)
434}
435
// Unit tests for the formatting, percentage and persistence helpers
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Byte counts are rendered in the largest fitting unit
    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1536), "1.50 KB");
        assert_eq!(format_bytes(1048576), "1.00 MB");
        assert_eq!(format_bytes(1073741824), "1.00 GB");
    }

    /// Percentage is absent without a total and handles the zero-byte case
    #[test]
    fn test_completion_percentage() {
        let mut progress = DownloadProgress::new(
            "testhash".to_string(),
            "cdn.test.com".to_string(),
            "/data".to_string(),
            PathBuf::from("/tmp/test.dat"),
        );

        // No total size set
        assert!(progress.completion_percentage().is_none());

        // With total size
        progress.total_size = Some(1000);
        progress.bytes_downloaded = 250;
        assert_eq!(progress.completion_percentage(), Some(25.0));

        // Complete download
        progress.bytes_downloaded = 1000;
        assert_eq!(progress.completion_percentage(), Some(100.0));

        // Zero-byte file
        progress.total_size = Some(0);
        progress.bytes_downloaded = 0;
        assert_eq!(progress.completion_percentage(), Some(100.0));
    }

    /// A saved record can be loaded back with its fields intact
    #[tokio::test]
    async fn test_progress_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let target_file = temp_dir.path().join("test.dat");

        let mut progress = DownloadProgress::new(
            "testhash123".to_string(),
            "cdn.example.com".to_string(),
            "/data".to_string(),
            target_file,
        );

        progress.total_size = Some(2048);
        progress.bytes_downloaded = 1024;

        // Save progress
        progress.save_to_file().await.unwrap();
        assert!(progress.progress_file.exists());

        // Load progress
        let loaded_progress = DownloadProgress::load_from_file(&progress.progress_file)
            .await
            .unwrap();
        assert_eq!(loaded_progress.file_hash, "testhash123");
        assert_eq!(loaded_progress.total_size, Some(2048));
        assert_eq!(loaded_progress.bytes_downloaded, 1024);
        assert_eq!(loaded_progress.cdn_host, "cdn.example.com");
    }

    /// Header values are parsed the way `extract_total_size` parses them
    #[test]
    fn test_extract_total_size_from_content_range() {
        // We can't directly set headers on a Response, so we'll test the parsing logic
        let content_range = "bytes 200-1023/2048";
        let total: Option<u64> = content_range.split('/').nth(1).and_then(|s| s.parse().ok());
        assert_eq!(total, Some(2048));

        // Test with content-length fallback
        let content_length = "1024";
        let length: Option<u64> = content_length.parse().ok();
        assert_eq!(length, Some(1024));
    }
}
531}