use crate::{Error, HttpClient, Result};
use reqwest::Response;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tokio::fs::{File, OpenOptions};
use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
use tracing::{debug, info, warn};
/// Persisted state of a single resumable download.
///
/// Serialized as JSON to `progress_file` by `DownloadProgress::save_to_file` and read back by
/// `DownloadProgress::load_from_file`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloadProgress {
/// Total size of the remote file in bytes; `None` until it can be derived from response headers.
pub total_size: Option<u64>,
/// Number of bytes written to `target_file` so far.
pub bytes_downloaded: u64,
/// Identifier passed to `HttpClient::download_file_range` and used in log messages.
pub file_hash: String,
/// Host passed to `HttpClient::download_file_range`.
pub cdn_host: String,
/// Path passed to `HttpClient::download_file_range`.
pub cdn_path: String,
/// Local file the downloaded bytes are written to.
pub target_file: PathBuf,
/// Where this record is stored as JSON; `DownloadProgress::new` derives it from `target_file`
/// by replacing the extension with `download`.
pub progress_file: PathBuf,
/// Set once the whole response body has been streamed to `target_file`.
pub is_complete: bool,
/// Unix timestamp in seconds of the most recent update to this record.
/// NOTE: `load_from_file` resets it to the time of loading.
pub last_updated: u64,
}
/// Drives the download described by a [`DownloadProgress`], resuming from previously saved
/// progress when the partial file on disk allows it.
#[derive(Debug)]
pub struct ResumableDownload {
/// Client used to issue the (range) download requests.
client: HttpClient,
/// Progress state, updated while streaming and periodically saved to its progress file.
progress: DownloadProgress,
}
impl DownloadProgress {
    /// Creates progress state for a download that has not started yet.
    ///
    /// The progress file lives next to `target_file` with its extension replaced by
    /// `download` (for example `data.bin` becomes `data.download`).
    pub fn new(
        file_hash: String,
        cdn_host: String,
        cdn_path: String,
        target_file: PathBuf,
    ) -> Self {
        let progress_file = target_file.with_extension("download");
        Self {
            total_size: None,
            bytes_downloaded: 0,
            file_hash,
            cdn_host,
            cdn_path,
            target_file,
            progress_file,
            is_complete: false,
            last_updated: current_timestamp(),
        }
    }

    /// Loads progress previously written by [`DownloadProgress::save_to_file`].
    ///
    /// Note: `last_updated` is replaced with the current time, so the timestamp stored in the
    /// file is not visible to callers of this function.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or does not contain valid progress JSON.
    pub async fn load_from_file(progress_file: &Path) -> Result<Self> {
        let content = tokio::fs::read_to_string(progress_file).await?;
        let mut progress: DownloadProgress = serde_json::from_str(&content)?;
        progress.last_updated = current_timestamp();
        Ok(progress)
    }

    /// Persists this progress as pretty-printed JSON at `progress_file`.
    ///
    /// The JSON is first written to a sibling `<progress_file>.tmp` and then renamed over
    /// `progress_file`. An interruption during the write therefore leaves the previous progress
    /// file untouched instead of a truncated one.
    ///
    /// # Errors
    /// Returns an error if serialization, the write, or the rename fails.
    pub async fn save_to_file(&self) -> Result<()> {
        let content = serde_json::to_string_pretty(self)?;
        let mut staging = self.progress_file.clone().into_os_string();
        staging.push(".tmp");
        let staging = PathBuf::from(staging);
        tokio::fs::write(&staging, content).await?;
        tokio::fs::rename(&staging, &self.progress_file).await?;
        debug!("Saved download progress to {:?}", self.progress_file);
        Ok(())
    }

    /// Checks whether `target_file` can be resumed from `bytes_downloaded`.
    ///
    /// Returns `Ok(true)` when all of the following hold:
    /// - the file exists;
    /// - the file holds at least `bytes_downloaded` bytes;
    /// - if `total_size` is known, neither the file nor the recorded progress exceeds it.
    ///
    /// Returns `Ok(false)` otherwise, including when the file's metadata cannot be read.
    /// Never returns `Err`.
    pub async fn verify_existing_file(&self) -> Result<bool> {
        let file_size = match tokio::fs::metadata(&self.target_file).await {
            Ok(metadata) => metadata.len(),
            Err(_) => return Ok(false),
        };
        // Extra bytes past `bytes_downloaded` are tolerated: data written after the last
        // periodic save is still on disk but is not covered by the recorded progress.
        let has_recorded_prefix = file_size >= self.bytes_downloaded;
        // A resumable file is normally shorter than the total, so only an upper bound is
        // enforced here (an equality check would reject every partial download).
        let fits_total = match self.total_size {
            Some(total) => self.bytes_downloaded <= total && file_size <= total,
            None => true,
        };
        Ok(has_recorded_prefix && fits_total)
    }

    /// Percentage of `total_size` downloaded so far, or `None` while the total is unknown.
    ///
    /// A total size of zero is reported as 100%.
    pub fn completion_percentage(&self) -> Option<f64> {
        self.total_size.map(|total| {
            if total == 0 {
                100.0
            } else {
                (self.bytes_downloaded as f64 / total as f64) * 100.0
            }
        })
    }

    /// Human-readable progress for logging.
    ///
    /// Produces e.g. `"1.00 MB/4.00 MB (25.0%)"`. While the total is unknown it produces just
    /// the downloaded amount, e.g. `"512 B"`.
    pub fn progress_string(&self) -> String {
        let downloaded = format_bytes(self.bytes_downloaded);
        match (self.total_size, self.completion_percentage()) {
            (Some(total), Some(percent)) => {
                format!("{}/{} ({:.1}%)", downloaded, format_bytes(total), percent)
            }
            // `format_bytes` already appends a unit, so no extra "bytes" suffix is added.
            _ => downloaded,
        }
    }
}
impl ResumableDownload {
    /// Pairs an HTTP client with the progress state it will drive.
    pub fn new(client: HttpClient, progress: DownloadProgress) -> Self {
        Self { client, progress }
    }

    /// Starts a new download, or resumes an interrupted one, and streams it to `target_file`.
    ///
    /// A resume is attempted only when `bytes_downloaded > 0` and
    /// [`DownloadProgress::verify_existing_file`] accepts the file on disk. Otherwise
    /// `bytes_downloaded` is reset and the download starts from byte 0.
    ///
    /// If the recorded progress already equals the known total size, the download is marked
    /// complete without issuing another request.
    ///
    /// # Errors
    /// Propagates I/O, client, stream and serialization errors. Returns
    /// `Error::InvalidResponse` when the server answers with a status other than 200 or 206.
    pub async fn start_or_resume(&mut self) -> Result<()> {
        let can_resume = self.progress.bytes_downloaded > 0
            && self.progress.verify_existing_file().await.unwrap_or(false);

        if can_resume && self.progress.total_size == Some(self.progress.bytes_downloaded) {
            // Every byte is already on disk. A range starting at the total size is
            // unsatisfiable; servers typically answer 416, a status this downloader rejects.
            self.progress.is_complete = true;
            self.progress.last_updated = current_timestamp();
            self.progress.save_to_file().await?;
            info!("Download already complete: {}", self.progress.progress_string());
            return Ok(());
        }

        if can_resume {
            info!(
                "Resuming download from {} bytes for {}",
                self.progress.bytes_downloaded, self.progress.file_hash
            );
        } else {
            info!("Starting new download for {}", self.progress.file_hash);
            self.progress.bytes_downloaded = 0;
        }
        self.progress.save_to_file().await?;
        self.download_with_resume().await
    }

    /// Requests the remaining bytes and streams them into `target_file`, marking the download
    /// complete and saving progress on success.
    async fn download_with_resume(&mut self) -> Result<()> {
        let mut file = OpenOptions::new()
            .create(true)
            .write(true)
            .read(true)
            .truncate(false)
            .open(&self.progress.target_file)
            .await?;
        // Keep only the prefix that is known to be valid. Stale bytes (from an older attempt,
        // or a file that failed verification) would otherwise remain behind the new data,
        // because the file is opened without truncation.
        file.set_len(self.progress.bytes_downloaded).await?;
        file.seek(SeekFrom::Start(self.progress.bytes_downloaded))
            .await?;

        let range = (self.progress.bytes_downloaded, None);
        let response = self
            .client
            .download_file_range(
                &self.progress.cdn_host,
                &self.progress.cdn_path,
                &self.progress.file_hash,
                range,
            )
            .await?;

        match response.status() {
            reqwest::StatusCode::PARTIAL_CONTENT => {
                debug!(
                    "Server supports range requests, resuming from byte {}",
                    self.progress.bytes_downloaded
                );
            }
            reqwest::StatusCode::OK => {
                if self.progress.bytes_downloaded > 0 {
                    warn!(
                        "Server doesn't support range requests, restarting download from beginning"
                    );
                    file.seek(SeekFrom::Start(0)).await?;
                    file.set_len(0).await?;
                    self.progress.bytes_downloaded = 0;
                }
            }
            status => {
                warn!(
                    "Unexpected HTTP status {} while downloading {}",
                    status, self.progress.file_hash
                );
                return Err(Error::InvalidResponse);
            }
        }

        // Derive the total only after the status is known. A 200 response carries the whole
        // body, so its Content-Length must not be offset by the abandoned resume position.
        if self.progress.total_size.is_none() {
            self.progress.total_size =
                extract_total_size(&response, self.progress.bytes_downloaded);
        }

        self.stream_response_to_file(response, &mut file).await?;
        self.progress.is_complete = true;
        self.progress.save_to_file().await?;
        info!("Download completed: {}", self.progress.progress_string());
        Ok(())
    }

    /// Writes the response body into `file` at its current position, advancing
    /// `bytes_downloaded`.
    ///
    /// Each time at least 1 MiB has been written since the last save, the file is flushed and
    /// the progress is saved. An interruption therefore loses roughly that much progress.
    async fn stream_response_to_file(&mut self, response: Response, file: &mut File) -> Result<()> {
        use futures_util::StreamExt;
        const SAVE_INTERVAL: u64 = 1024 * 1024;

        let mut stream = response.bytes_stream();
        let mut bytes_written_since_save = 0u64;
        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result.map_err(Error::Http)?;
            file.write_all(&chunk).await?;
            let chunk_size = chunk.len() as u64;
            self.progress.bytes_downloaded += chunk_size;
            bytes_written_since_save += chunk_size;
            if bytes_written_since_save >= SAVE_INTERVAL {
                // Flush before recording progress so the saved count does not run ahead of
                // writes still pending on the file handle.
                file.flush().await?;
                self.progress.last_updated = current_timestamp();
                self.progress.save_to_file().await?;
                bytes_written_since_save = 0;
                debug!("Progress: {}", self.progress.progress_string());
            }
        }
        file.flush().await?;
        self.progress.last_updated = current_timestamp();
        Ok(())
    }

    /// Current progress state.
    pub fn progress(&self) -> &DownloadProgress {
        &self.progress
    }

    /// Removes the progress file; the partially written `target_file` is left in place.
    ///
    /// # Errors
    /// Returns an error if the progress file exists but cannot be removed.
    pub async fn cancel(&self) -> Result<()> {
        if Self::remove_if_exists(&self.progress.progress_file).await? {
            debug!("Removed progress file {:?}", self.progress.progress_file);
        }
        Ok(())
    }

    /// Removes the progress file if the download has completed; otherwise does nothing.
    ///
    /// # Errors
    /// Returns an error if the progress file exists but cannot be removed.
    pub async fn cleanup_completed(&self) -> Result<()> {
        if self.progress.is_complete && Self::remove_if_exists(&self.progress.progress_file).await?
        {
            debug!("Cleaned up progress file for completed download");
        }
        Ok(())
    }

    /// Deletes `path`, returning `Ok(false)` instead of an error when it does not exist.
    ///
    /// Attempting the removal directly avoids a blocking `exists()` probe and the race between
    /// probing and deleting.
    async fn remove_if_exists(path: &Path) -> Result<bool> {
        match tokio::fs::remove_file(path).await {
            Ok(()) => Ok(true),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
            Err(e) => Err(e.into()),
        }
    }
}
fn extract_total_size(response: &Response, bytes_already_downloaded: u64) -> Option<u64> {
if let Some(content_range) = response.headers().get("content-range") {
if let Ok(range_str) = content_range.to_str() {
if let Some(total_str) = range_str.split('/').nth(1) {
if let Ok(total) = total_str.parse::<u64>() {
return Some(total);
}
}
}
}
if let Some(content_length) = response.headers().get("content-length") {
if let Ok(length_str) = content_length.to_str() {
if let Ok(length) = length_str.parse::<u64>() {
return Some(length + bytes_already_downloaded);
}
}
}
None
}
/// Renders a byte count with a 1024-based unit.
///
/// Values under 1 KB are printed as plain integers (e.g. `"512 B"`). Larger values get two
/// decimals (e.g. `"1.50 KB"`), with TB as the largest unit.
fn format_bytes(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];
    const STEP: f64 = 1024.0;

    if bytes < 1024 {
        return format!("{} {}", bytes, UNITS[0]);
    }

    let mut scaled = bytes as f64;
    let mut unit = UNITS[0];
    for &larger in &UNITS[1..] {
        if scaled < STEP {
            break;
        }
        scaled /= STEP;
        unit = larger;
    }
    format!("{:.2} {}", scaled, unit)
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Lists the unfinished downloads whose progress files (`*.download`) live directly in `dir`.
///
/// Progress files that fail to load are logged and skipped. Records marked complete are left
/// out. A missing `dir` yields an empty list.
///
/// # Errors
/// Returns an error if `dir` exists but cannot be listed.
pub async fn find_resumable_downloads(dir: &Path) -> Result<Vec<DownloadProgress>> {
    let mut pending = Vec::new();
    if !dir.exists() {
        return Ok(pending);
    }

    let mut listing = tokio::fs::read_dir(dir).await?;
    while let Some(item) = listing.next_entry().await? {
        let candidate = item.path();
        let is_progress_file =
            candidate.extension().and_then(|ext| ext.to_str()) == Some("download");
        if !is_progress_file {
            continue;
        }
        match DownloadProgress::load_from_file(&candidate).await {
            Ok(record) if !record.is_complete => pending.push(record),
            Ok(_) => {}
            Err(err) => {
                warn!("Failed to load download progress from {:?}: {}", candidate, err);
            }
        }
    }
    Ok(pending)
}
pub async fn cleanup_old_progress_files(dir: &Path, max_age_hours: u64) -> Result<usize> {
let max_age_secs = max_age_hours * 3600;
let current_time = current_timestamp();
let mut cleaned_count = 0;
if !dir.exists() {
return Ok(0);
}
let mut entries = tokio::fs::read_dir(dir).await?;
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("download") {
match DownloadProgress::load_from_file(&path).await {
Ok(progress) => {
let age = current_time.saturating_sub(progress.last_updated);
if progress.is_complete
&& age > max_age_secs
&& tokio::fs::remove_file(&path).await.is_ok()
{
cleaned_count += 1;
debug!("Cleaned up old progress file: {:?}", path);
}
}
Err(_) => {
if let Ok(metadata) = tokio::fs::metadata(&path).await {
if let Ok(modified) = metadata.modified() {
let file_age = std::time::SystemTime::now()
.duration_since(modified)
.unwrap_or_default()
.as_secs();
if file_age > max_age_secs
&& tokio::fs::remove_file(&path).await.is_ok()
{
cleaned_count += 1;
debug!("Cleaned up corrupted progress file: {:?}", path);
}
}
}
}
}
}
}
Ok(cleaned_count)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a fresh progress record for `target_file` with fixed CDN coordinates.
    fn progress_for(file_hash: &str, target_file: PathBuf) -> DownloadProgress {
        DownloadProgress::new(
            file_hash.to_string(),
            "cdn.example.com".to_string(),
            "/data".to_string(),
            target_file,
        )
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(0), "0 B");
        assert_eq!(format_bytes(512), "512 B");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1536), "1.50 KB");
        assert_eq!(format_bytes(1048576), "1.00 MB");
        assert_eq!(format_bytes(1073741824), "1.00 GB");
    }

    #[test]
    fn test_new_places_progress_file_next_to_target() {
        let progress = progress_for("testhash", PathBuf::from("/tmp/test.dat"));
        assert_eq!(progress.progress_file, PathBuf::from("/tmp/test.download"));
        assert_eq!(progress.bytes_downloaded, 0);
        assert!(progress.total_size.is_none());
        assert!(!progress.is_complete);
    }

    #[test]
    fn test_completion_percentage() {
        let mut progress = progress_for("testhash", PathBuf::from("/tmp/test.dat"));
        assert!(progress.completion_percentage().is_none());
        progress.total_size = Some(1000);
        progress.bytes_downloaded = 250;
        assert_eq!(progress.completion_percentage(), Some(25.0));
        progress.bytes_downloaded = 1000;
        assert_eq!(progress.completion_percentage(), Some(100.0));
        progress.total_size = Some(0);
        progress.bytes_downloaded = 0;
        assert_eq!(progress.completion_percentage(), Some(100.0));
    }

    #[tokio::test]
    async fn test_progress_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let target_file = temp_dir.path().join("test.dat");
        let mut progress = progress_for("testhash123", target_file);
        progress.total_size = Some(2048);
        progress.bytes_downloaded = 1024;
        progress.save_to_file().await.unwrap();
        assert!(progress.progress_file.exists());
        let loaded_progress = DownloadProgress::load_from_file(&progress.progress_file)
            .await
            .unwrap();
        assert_eq!(loaded_progress.file_hash, "testhash123");
        assert_eq!(loaded_progress.total_size, Some(2048));
        assert_eq!(loaded_progress.bytes_downloaded, 1024);
        assert_eq!(loaded_progress.cdn_host, "cdn.example.com");
    }

    #[tokio::test]
    async fn test_verify_existing_file() {
        let temp_dir = TempDir::new().unwrap();
        let target_file = temp_dir.path().join("partial.dat");
        let mut progress = progress_for("partial", target_file.clone());
        progress.bytes_downloaded = 5;
        // No file on disk: nothing to resume from.
        assert!(!progress.verify_existing_file().await.unwrap());
        tokio::fs::write(&target_file, b"0123456789").await.unwrap();
        // The file covers the recorded prefix.
        assert!(progress.verify_existing_file().await.unwrap());
        // The recorded progress claims more than the file holds.
        progress.bytes_downloaded = 11;
        assert!(!progress.verify_existing_file().await.unwrap());
    }

    #[tokio::test]
    async fn test_find_resumable_downloads_skips_complete_and_missing_dir() {
        let temp_dir = TempDir::new().unwrap();
        let missing = temp_dir.path().join("does-not-exist");
        assert!(find_resumable_downloads(&missing).await.unwrap().is_empty());

        let pending = progress_for("pending", temp_dir.path().join("a.dat"));
        pending.save_to_file().await.unwrap();
        let mut done = progress_for("done", temp_dir.path().join("b.dat"));
        done.is_complete = true;
        done.save_to_file().await.unwrap();

        let found = find_resumable_downloads(temp_dir.path()).await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].file_hash, "pending");
    }

    #[test]
    fn test_extract_total_size_from_content_range() {
        // `extract_total_size` takes a `reqwest::Response`; this test only pins the
        // header-parsing rules it applies to the raw header text.
        let parse_total =
            |content_range: &str| -> Option<u64> { content_range.split('/').nth(1)?.parse().ok() };
        assert_eq!(parse_total("bytes 200-1023/2048"), Some(2048));
        // An unknown total (`*`) does not parse, so the Content-Length fallback applies.
        assert_eq!(parse_total("bytes 200-1023/*"), None);
        let length: Option<u64> = "1024".parse().ok();
        assert_eq!(length, Some(1024));
    }
}