vx_installer/
downloader.rs

//! Download utilities for vx-installer: HTTP fetching with progress reporting and SHA-256 verification.
2
3use crate::{progress::ProgressContext, Error, Result, USER_AGENT};
4use futures_util::StreamExt;
5use sha2::Digest;
6use std::path::{Path, PathBuf};
7
8/// HTTP downloader for fetching files from URLs
9pub struct Downloader {
10    client: reqwest::Client,
11}
12
13impl Downloader {
14    /// Create a new downloader with default configuration
15    pub fn new() -> Result<Self> {
16        let client = reqwest::Client::builder()
17            .user_agent(USER_AGENT)
18            .timeout(std::time::Duration::from_secs(300)) // 5 minutes
19            .build()?;
20
21        Ok(Self { client })
22    }
23
24    /// Create a downloader with custom client configuration
25    pub fn with_client(client: reqwest::Client) -> Self {
26        Self { client }
27    }
28
29    /// Download a file from URL to the specified path
30    pub async fn download(
31        &self,
32        url: &str,
33        output_path: &Path,
34        progress: &ProgressContext,
35    ) -> Result<()> {
36        // Ensure parent directory exists
37        if let Some(parent) = output_path.parent() {
38            std::fs::create_dir_all(parent)?;
39        }
40
41        // Start the download request
42        let response = self
43            .client
44            .get(url)
45            .send()
46            .await
47            .map_err(|e| Error::download_failed(url, e.to_string()))?;
48
49        // Check response status
50        if !response.status().is_success() {
51            return Err(Error::download_failed(
52                url,
53                format!("HTTP {}", response.status()),
54            ));
55        }
56
57        // Get content length for progress tracking
58        let total_size = response.content_length();
59
60        // Extract filename for progress display
61        let filename = self.extract_filename_from_url(url);
62        let message = format!("Downloading {}", filename);
63
64        progress.start(&message, total_size).await?;
65
66        // Create the output file
67        let mut file = std::fs::File::create(output_path)?;
68        let mut stream = response.bytes_stream();
69        let mut downloaded = 0u64;
70
71        // Download with progress tracking
72        while let Some(chunk) = stream.next().await {
73            let chunk = chunk.map_err(|e| Error::download_failed(url, e.to_string()))?;
74
75            std::io::Write::write_all(&mut file, &chunk)?;
76            downloaded += chunk.len() as u64;
77
78            progress.update(downloaded, None).await?;
79        }
80
81        // Ensure all data is written
82        std::io::Write::flush(&mut file)?;
83
84        progress.finish("Download completed").await?;
85
86        Ok(())
87    }
88
89    /// Download a file to a temporary location and return the path
90    pub async fn download_temp(&self, url: &str, progress: &ProgressContext) -> Result<PathBuf> {
91        let filename = self.extract_filename_from_url(url);
92        let temp_dir = tempfile::tempdir()?;
93        let temp_path = temp_dir.path().join(filename);
94
95        self.download(url, &temp_path, progress).await?;
96
97        // Convert to a persistent path (caller is responsible for cleanup)
98        let persistent_path = temp_path.clone();
99        std::mem::forget(temp_dir); // Prevent automatic cleanup
100
101        Ok(persistent_path)
102    }
103
104    /// Download and verify checksum
105    pub async fn download_with_checksum(
106        &self,
107        url: &str,
108        output_path: &Path,
109        expected_checksum: &str,
110        progress: &ProgressContext,
111    ) -> Result<()> {
112        // Download the file
113        self.download(url, output_path, progress).await?;
114
115        // Verify checksum
116        let actual_checksum = self.calculate_sha256(output_path)?;
117        if actual_checksum != expected_checksum {
118            return Err(Error::ChecksumMismatch {
119                file_path: output_path.to_path_buf(),
120                expected: expected_checksum.to_string(),
121                actual: actual_checksum,
122            });
123        }
124
125        Ok(())
126    }
127
128    /// Get the size of a remote file without downloading it
129    pub async fn get_file_size(&self, url: &str) -> Result<Option<u64>> {
130        let response = self
131            .client
132            .head(url)
133            .send()
134            .await
135            .map_err(|e| Error::download_failed(url, e.to_string()))?;
136
137        if !response.status().is_success() {
138            return Err(Error::download_failed(
139                url,
140                format!("HTTP {}", response.status()),
141            ));
142        }
143
144        Ok(response.content_length())
145    }
146
147    /// Check if a URL is accessible
148    pub async fn check_url(&self, url: &str) -> Result<bool> {
149        match self.client.head(url).send().await {
150            Ok(response) => Ok(response.status().is_success()),
151            Err(_) => Ok(false),
152        }
153    }
154
155    /// Extract filename from URL
156    fn extract_filename_from_url(&self, url: &str) -> String {
157        let filename = url
158            .split('/')
159            .next_back()
160            .unwrap_or("download")
161            .split('?')
162            .next()
163            .unwrap_or("download");
164
165        if filename.is_empty() {
166            "download".to_string()
167        } else {
168            filename.to_string()
169        }
170    }
171
172    /// Calculate SHA256 checksum of a file
173    fn calculate_sha256(&self, file_path: &Path) -> Result<String> {
174        use std::io::Read;
175
176        let mut file = std::fs::File::open(file_path)?;
177        let mut hasher = sha2::Sha256::new();
178        let mut buffer = [0; 8192];
179
180        loop {
181            let bytes_read = file.read(&mut buffer)?;
182            if bytes_read == 0 {
183                break;
184            }
185            hasher.update(&buffer[..bytes_read]);
186        }
187
188        Ok(format!("{:x}", hasher.finalize()))
189    }
190}
191
192impl Default for Downloader {
193    fn default() -> Self {
194        Self::new().expect("Failed to create default downloader")
195    }
196}
197
/// Configuration for download operations.
///
/// Construct one with [`DownloadConfig::new`] and adjust it with the `with_*`
/// builder methods. Configurations can be compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadConfig {
    /// URL to download from.
    pub url: String,
    /// Destination path for the downloaded file.
    pub output_path: PathBuf,
    /// Expected checksum, if verification is wanted.
    pub checksum: Option<String>,
    /// Maximum number of retry attempts.
    pub max_retries: u32,
    /// Timeout for the download operation.
    pub timeout: std::time::Duration,
    /// Whether an existing file at `output_path` may be overwritten.
    pub overwrite: bool,
}
214
215impl DownloadConfig {
216    /// Create a new download configuration
217    pub fn new(url: impl Into<String>, output_path: impl Into<PathBuf>) -> Self {
218        Self {
219            url: url.into(),
220            output_path: output_path.into(),
221            checksum: None,
222            max_retries: 3,
223            timeout: std::time::Duration::from_secs(300),
224            overwrite: false,
225        }
226    }
227
228    /// Set the expected checksum
229    pub fn with_checksum(mut self, checksum: impl Into<String>) -> Self {
230        self.checksum = Some(checksum.into());
231        self
232    }
233
234    /// Set the maximum number of retries
235    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
236        self.max_retries = max_retries;
237        self
238    }
239
240    /// Set the timeout
241    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
242        self.timeout = timeout;
243        self
244    }
245
246    /// Set whether to overwrite existing files
247    pub fn with_overwrite(mut self, overwrite: bool) -> Self {
248        self.overwrite = overwrite;
249        self
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Filename extraction: plain path, query string stripped, empty segment fallback.
    #[test]
    fn test_extract_filename_from_url() {
        let downloader = Downloader::default();
        let cases = [
            ("https://example.com/file.zip", "file.zip"),
            ("https://example.com/file.zip?version=1.0", "file.zip"),
            ("https://example.com/", "download"),
        ];

        for &(url, expected) in &cases {
            assert_eq!(
                downloader.extract_filename_from_url(url),
                expected,
                "url: {}",
                url
            );
        }
    }

    /// Builder methods store the values they are given.
    #[test]
    fn test_download_config() {
        let config = DownloadConfig::new("https://example.com/file.zip", "/tmp/file.zip")
            .with_overwrite(true)
            .with_max_retries(5)
            .with_checksum("abc123");

        let DownloadConfig {
            url,
            output_path,
            checksum,
            max_retries,
            overwrite,
            ..
        } = config;

        assert_eq!(url, "https://example.com/file.zip");
        assert_eq!(output_path, PathBuf::from("/tmp/file.zip"));
        assert_eq!(checksum.as_deref(), Some("abc123"));
        assert_eq!(max_retries, 5);
        assert!(overwrite);
    }
}
288}