use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use crate::config::{FetchOptions, FetchPhase};
use crate::error::{Error, Result};
use crate::fetch::fetcher::Fetcher;
use crate::net::http::HttpClient;
use crate::progress::Progress;
/// Persisted progress record for a single resumable download.
///
/// `ResumableFetcher` writes it as pretty-printed JSON into its checkpoint
/// directory; fields serialize in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloadCheckpoint {
    /// Source URL being downloaded.
    pub url: String,
    /// Local file the download is written to.
    pub destination: PathBuf,
    /// Total size of the resource as returned by the fetcher's `head` call, if known.
    pub total_bytes: Option<u64>,
    /// Number of bytes recorded as downloaded so far.
    pub downloaded_bytes: u64,
    /// Checksum of the partial data. NOTE(review): never set in this module — confirm
    /// whether anything else populates or reads it.
    pub partial_checksum: Option<String>,
    /// Seconds since the Unix epoch at which this record was created or last updated.
    pub last_update: u64,
}
impl DownloadCheckpoint {
    /// Creates a checkpoint with no recorded progress, stamped with the current time.
    pub fn new(url: String, destination: PathBuf, total_bytes: Option<u64>) -> Self {
        Self {
            url,
            destination,
            total_bytes,
            downloaded_bytes: 0,
            partial_checksum: None,
            last_update: Self::unix_now_secs(),
        }
    }

    /// Records `downloaded_bytes` as the progress so far and refreshes `last_update`.
    pub fn update_progress(&mut self, downloaded_bytes: u64) {
        self.downloaded_bytes = downloaded_bytes;
        self.last_update = Self::unix_now_secs();
    }

    /// Returns `true` when a ranged request can continue this download.
    ///
    /// Requires some recorded progress and, when the total size is known, that bytes
    /// actually remain: a range starting at or beyond the end of the resource is
    /// unsatisfiable (HTTP 416), so a finished or inconsistent checkpoint is not
    /// resumable and should be restarted instead.
    pub fn can_resume(&self) -> bool {
        let bytes_remain = match self.total_bytes {
            Some(total) => self.downloaded_bytes < total,
            // Size unknown: nothing to compare against, so allow the attempt.
            None => true,
        };
        self.downloaded_bytes > 0 && bytes_remain
    }

    /// Returns an open-ended HTTP `Range` header value starting at `downloaded_bytes`.
    pub fn range_header(&self) -> String {
        format!("bytes={}-", self.downloaded_bytes)
    }

    /// Current Unix time in whole seconds; a clock set before the epoch yields 0.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }
}
/// Wraps a [`Fetcher`] and persists download progress as JSON checkpoints so an
/// interrupted download can be continued from a recorded byte offset.
pub struct ResumableFetcher<C: HttpClient> {
    /// Performs the underlying `head` and download requests.
    base_fetcher: Fetcher<C>,
    /// `<workspace_root>/.checkpoints`, holding one checkpoint file per url/destination pair.
    checkpoint_dir: PathBuf,
}
impl<C: HttpClient + 'static> ResumableFetcher<C> {
pub fn new(client: C, workspace_root: impl Into<PathBuf>) -> Self {
let workspace_root = workspace_root.into();
Self {
base_fetcher: Fetcher::new(client, workspace_root.clone()),
checkpoint_dir: workspace_root.join(".checkpoints"),
}
}
pub async fn fetch_resumable(
&self,
url: &str,
destination: &Path,
options: FetchOptions,
) -> Result<PathBuf> {
fs::create_dir_all(&self.checkpoint_dir)
.await
.map_err(|e| Error::Network(e.to_string()))?;
let checkpoint_path = self.checkpoint_path(url, destination);
if let Ok(checkpoint) = self.load_checkpoint(&checkpoint_path).await
&& checkpoint.can_resume()
{
return self
.resume_download(&checkpoint, &checkpoint_path, options)
.await;
}
self.start_new_download(url, destination, &checkpoint_path, options)
.await
}
async fn start_new_download(
&self,
url: &str,
destination: &Path,
checkpoint_path: &Path,
options: FetchOptions,
) -> Result<PathBuf> {
let total_bytes = self
.base_fetcher
.head(url)
.await
.map_err(|e| Error::Network(e.to_string()))?;
let checkpoint =
DownloadCheckpoint::new(url.to_string(), destination.to_path_buf(), total_bytes);
self.save_checkpoint(&checkpoint, checkpoint_path).await?;
let checkpoint_path_clone = checkpoint_path.to_path_buf();
let checkpoint_dir = self.checkpoint_dir.clone();
let url_clone = url.to_string();
let destination_clone = destination.to_path_buf();
let mut options_with_checkpoint = options.clone();
let original_callback = options_with_checkpoint.on_progress.clone();
options_with_checkpoint.on_progress = Some(Arc::new(move |progress: &Progress| {
if progress.phase == FetchPhase::Downloading {
let mut checkpoint = DownloadCheckpoint::new(
url_clone.clone(),
destination_clone.clone(),
total_bytes,
);
checkpoint.update_progress(progress.bytes_downloaded);
let checkpoint_path = checkpoint_path_clone.clone();
let checkpoint_dir = checkpoint_dir.clone();
tokio::spawn(async move {
let _ = Self::save_checkpoint_static(
&checkpoint,
&checkpoint_path,
&checkpoint_dir,
)
.await;
});
}
if let Some(ref callback) = original_callback {
callback(progress);
}
}));
let result = self
.base_fetcher
.fetch_with_receipt(url, destination, options_with_checkpoint)
.await;
match result {
Ok(receipt) => {
let _ = fs::remove_file(checkpoint_path).await;
Ok(receipt.destination)
}
Err(e) => {
Err(e)
}
}
}
async fn resume_download(
&self,
checkpoint: &DownloadCheckpoint,
checkpoint_path: &Path,
options: FetchOptions,
) -> Result<PathBuf> {
if checkpoint.destination.exists() {
let metadata = fs::metadata(&checkpoint.destination)
.await
.map_err(|e| Error::Network(e.to_string()))?;
let current_size = metadata.len();
if current_size != checkpoint.downloaded_bytes {
let _ = fs::remove_file(&checkpoint.destination).await;
let _ = fs::remove_file(checkpoint_path).await;
return self
.start_new_download(
&checkpoint.url,
&checkpoint.destination,
checkpoint_path,
options,
)
.await;
}
}
let mut resume_options = options.clone();
resume_options.resume_offset = Some(checkpoint.downloaded_bytes);
let checkpoint_path_clone = checkpoint_path.to_path_buf();
let checkpoint_dir = self.checkpoint_dir.clone();
let initial_bytes = checkpoint.downloaded_bytes;
let original_total_bytes = checkpoint.total_bytes;
let checkpoint_url = checkpoint.url.clone();
let checkpoint_destination = checkpoint.destination.clone();
let original_callback = resume_options.on_progress.clone();
resume_options.on_progress = Some(Arc::new(move |progress: &Progress| {
if progress.phase == FetchPhase::Downloading {
let total_downloaded = initial_bytes + progress.bytes_downloaded;
let mut new_checkpoint = DownloadCheckpoint::new(
checkpoint_url.clone(),
checkpoint_destination.clone(),
original_total_bytes,
);
new_checkpoint.update_progress(total_downloaded);
let checkpoint_path = checkpoint_path_clone.clone();
let checkpoint_dir = checkpoint_dir.clone();
tokio::spawn(async move {
let _ = Self::save_checkpoint_static(
&new_checkpoint,
&checkpoint_path,
&checkpoint_dir,
)
.await;
});
}
if let Some(ref callback) = original_callback {
callback(progress);
}
}));
let result = self
.base_fetcher
.fetch_with_receipt(&checkpoint.url, &checkpoint.destination, resume_options)
.await;
match result {
Ok(receipt) => {
let _ = fs::remove_file(checkpoint_path).await;
Ok(receipt.destination)
}
Err(e) => {
Err(e)
}
}
}
fn checkpoint_path(&self, url: &str, destination: &Path) -> PathBuf {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
url.hash(&mut hasher);
destination.hash(&mut hasher);
let hash = hasher.finish();
self.checkpoint_dir
.join(format!("checkpoint_{:016x}.json", hash))
}
async fn load_checkpoint(&self, path: &Path) -> Result<DownloadCheckpoint> {
let content = fs::read_to_string(path)
.await
.map_err(|e| Error::Network(e.to_string()))?;
serde_json::from_str(&content)
.map_err(|e| Error::InvalidState(format!("Invalid checkpoint: {}", e)))
}
async fn save_checkpoint(&self, checkpoint: &DownloadCheckpoint, path: &Path) -> Result<()> {
Self::save_checkpoint_static(checkpoint, path, &self.checkpoint_dir).await
}
async fn save_checkpoint_static(
checkpoint: &DownloadCheckpoint,
path: &Path,
checkpoint_dir: &Path,
) -> Result<()> {
fs::create_dir_all(checkpoint_dir)
.await
.map_err(|e| Error::Network(e.to_string()))?;
let content = serde_json::to_string_pretty(checkpoint)
.map_err(|e| Error::InvalidState(format!("Failed to serialize checkpoint: {}", e)))?;
let temp_path = path.with_extension("tmp");
{
let mut file: tokio::fs::File = fs::File::create(&temp_path)
.await
.map_err(|e| Error::Network(e.to_string()))?;
file.write_all(content.as_bytes())
.await
.map_err(|e| Error::Network(e.to_string()))?;
file.sync_all()
.await
.map_err(|e| Error::Network(e.to_string()))?;
}
fs::rename(&temp_path, path)
.await
.map_err(|e| Error::Network(e.to_string()))?;
Ok(())
}
pub async fn cleanup_old_checkpoints(&self, max_age_seconds: u64) -> Result<usize> {
let mut cleaned = 0;
let cutoff = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
- max_age_seconds;
let mut entries = fs::read_dir(&self.checkpoint_dir)
.await
.map_err(|e| Error::Network(e.to_string()))?;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| Error::Network(e.to_string()))?
{
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
match self.load_checkpoint(&path).await {
Ok(checkpoint) => {
if checkpoint.last_update < cutoff {
let _ = fs::remove_file(&path).await;
cleaned += 1;
}
}
Err(_) => {
let _ = fs::remove_file(&path).await;
cleaned += 1;
}
}
}
}
Ok(cleaned)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::net::http::BoxStream;
    use bytes::Bytes;
    use tempfile::TempDir;

    /// Stub HTTP client: every stream is empty and every HEAD reports 1024 bytes.
    #[derive(Debug)]
    struct MockClient;

    impl MockClient {
        fn new() -> Self {
            MockClient
        }
    }

    /// Error type demanded by `HttpClient`; `MockClient` never actually returns it.
    #[derive(Debug)]
    struct MockError(String);

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(&self.0)
        }
    }

    impl std::error::Error for MockError {}

    impl HttpClient for MockClient {
        type Error = MockError;
        async fn stream(
            &self,
            _url: &str,
            _headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            let empty: BoxStream<'static, std::result::Result<Bytes, Self::Error>> =
                Box::pin(futures_util::stream::empty());
            Ok(empty)
        }
        async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            Ok(Some(1024))
        }
    }

    /// Fresh checkpoint for `https://example.com/<name>` -> `/tmp/<name>`, 1024 bytes total.
    fn sample_checkpoint(name: &str) -> DownloadCheckpoint {
        DownloadCheckpoint::new(
            format!("https://example.com/{name}"),
            PathBuf::from(format!("/tmp/{name}")),
            Some(1024),
        )
    }

    /// Current Unix time in whole seconds.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    #[test]
    fn test_download_checkpoint() {
        let mut checkpoint = sample_checkpoint("file.txt");

        // Freshly created: nothing to resume yet.
        assert_eq!(checkpoint.downloaded_bytes, 0);
        assert!(!checkpoint.can_resume());
        assert_eq!(checkpoint.range_header(), "bytes=0-");

        // Half-way through: resumable from the recorded offset.
        checkpoint.update_progress(512);
        assert_eq!(checkpoint.downloaded_bytes, 512);
        assert!(checkpoint.can_resume());
        assert_eq!(checkpoint.range_header(), "bytes=512-");
    }

    #[tokio::test]
    async fn test_checkpoint_save_load() {
        let temp_dir = TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.json");

        let mut original = sample_checkpoint("file.txt");
        assert_eq!(original.downloaded_bytes, 0);
        assert!(!original.can_resume());
        assert_eq!(original.range_header(), "bytes=0-");
        original.update_progress(512);

        let fetcher: ResumableFetcher<MockClient> =
            ResumableFetcher::new(MockClient::new(), temp_dir.path());
        fetcher
            .save_checkpoint(&original, &checkpoint_path)
            .await
            .unwrap();
        let loaded = fetcher.load_checkpoint(&checkpoint_path).await.unwrap();

        // Every persisted field must survive the JSON round trip.
        assert_eq!(loaded.url, original.url);
        assert_eq!(loaded.destination, original.destination);
        assert_eq!(loaded.total_bytes, original.total_bytes);
        assert_eq!(loaded.downloaded_bytes, original.downloaded_bytes);
    }

    #[tokio::test]
    async fn test_cleanup_old_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let fetcher = ResumableFetcher::<MockClient>::new(MockClient::new(), temp_dir.path());
        let ten_seconds_ago = now_secs() - 10;

        let mut saved_paths = Vec::new();
        for name in ["file1.txt", "file2.txt"] {
            let mut checkpoint = sample_checkpoint(name);
            checkpoint.last_update = ten_seconds_ago;
            let path = fetcher.checkpoint_path(&checkpoint.url, &checkpoint.destination);
            fetcher.save_checkpoint(&checkpoint, &path).await.unwrap();
            saved_paths.push(path);
        }

        // Both checkpoints are older than the 5-second limit.
        let cleaned = fetcher.cleanup_old_checkpoints(5).await.unwrap();
        assert_eq!(cleaned, 2);
        for path in &saved_paths {
            assert!(!path.exists());
        }
    }
}