vx_installer/
downloader.rs1use crate::{progress::ProgressContext, Error, Result, USER_AGENT};
4use futures_util::StreamExt;
5use sha2::Digest;
6use std::path::{Path, PathBuf};
7
8pub struct Downloader {
10 client: reqwest::Client,
11}
12
13impl Downloader {
14 pub fn new() -> Result<Self> {
16 let client = reqwest::Client::builder()
17 .user_agent(USER_AGENT)
18 .timeout(std::time::Duration::from_secs(300)) .build()?;
20
21 Ok(Self { client })
22 }
23
24 pub fn with_client(client: reqwest::Client) -> Self {
26 Self { client }
27 }
28
29 pub async fn download(
31 &self,
32 url: &str,
33 output_path: &Path,
34 progress: &ProgressContext,
35 ) -> Result<()> {
36 if let Some(parent) = output_path.parent() {
38 std::fs::create_dir_all(parent)?;
39 }
40
41 let response = self
43 .client
44 .get(url)
45 .send()
46 .await
47 .map_err(|e| Error::download_failed(url, e.to_string()))?;
48
49 if !response.status().is_success() {
51 return Err(Error::download_failed(
52 url,
53 format!("HTTP {}", response.status()),
54 ));
55 }
56
57 let total_size = response.content_length();
59
60 let filename = self.extract_filename_from_url(url);
62 let message = format!("Downloading {}", filename);
63
64 progress.start(&message, total_size).await?;
65
66 let mut file = std::fs::File::create(output_path)?;
68 let mut stream = response.bytes_stream();
69 let mut downloaded = 0u64;
70
71 while let Some(chunk) = stream.next().await {
73 let chunk = chunk.map_err(|e| Error::download_failed(url, e.to_string()))?;
74
75 std::io::Write::write_all(&mut file, &chunk)?;
76 downloaded += chunk.len() as u64;
77
78 progress.update(downloaded, None).await?;
79 }
80
81 std::io::Write::flush(&mut file)?;
83
84 progress.finish("Download completed").await?;
85
86 Ok(())
87 }
88
89 pub async fn download_temp(&self, url: &str, progress: &ProgressContext) -> Result<PathBuf> {
91 let filename = self.extract_filename_from_url(url);
92 let temp_dir = tempfile::tempdir()?;
93 let temp_path = temp_dir.path().join(filename);
94
95 self.download(url, &temp_path, progress).await?;
96
97 let persistent_path = temp_path.clone();
99 std::mem::forget(temp_dir); Ok(persistent_path)
102 }
103
104 pub async fn download_with_checksum(
106 &self,
107 url: &str,
108 output_path: &Path,
109 expected_checksum: &str,
110 progress: &ProgressContext,
111 ) -> Result<()> {
112 self.download(url, output_path, progress).await?;
114
115 let actual_checksum = self.calculate_sha256(output_path)?;
117 if actual_checksum != expected_checksum {
118 return Err(Error::ChecksumMismatch {
119 file_path: output_path.to_path_buf(),
120 expected: expected_checksum.to_string(),
121 actual: actual_checksum,
122 });
123 }
124
125 Ok(())
126 }
127
128 pub async fn get_file_size(&self, url: &str) -> Result<Option<u64>> {
130 let response = self
131 .client
132 .head(url)
133 .send()
134 .await
135 .map_err(|e| Error::download_failed(url, e.to_string()))?;
136
137 if !response.status().is_success() {
138 return Err(Error::download_failed(
139 url,
140 format!("HTTP {}", response.status()),
141 ));
142 }
143
144 Ok(response.content_length())
145 }
146
147 pub async fn check_url(&self, url: &str) -> Result<bool> {
149 match self.client.head(url).send().await {
150 Ok(response) => Ok(response.status().is_success()),
151 Err(_) => Ok(false),
152 }
153 }
154
155 fn extract_filename_from_url(&self, url: &str) -> String {
157 let filename = url
158 .split('/')
159 .next_back()
160 .unwrap_or("download")
161 .split('?')
162 .next()
163 .unwrap_or("download");
164
165 if filename.is_empty() {
166 "download".to_string()
167 } else {
168 filename.to_string()
169 }
170 }
171
172 fn calculate_sha256(&self, file_path: &Path) -> Result<String> {
174 use std::io::Read;
175
176 let mut file = std::fs::File::open(file_path)?;
177 let mut hasher = sha2::Sha256::new();
178 let mut buffer = [0; 8192];
179
180 loop {
181 let bytes_read = file.read(&mut buffer)?;
182 if bytes_read == 0 {
183 break;
184 }
185 hasher.update(&buffer[..bytes_read]);
186 }
187
188 Ok(format!("{:x}", hasher.finalize()))
189 }
190}
191
192impl Default for Downloader {
193 fn default() -> Self {
194 Self::new().expect("Failed to create default downloader")
195 }
196}
197
/// Declarative description of a single download request.
///
/// NOTE(review): `Downloader` in this file does not read this struct. `max_retries`,
/// `timeout` and `overwrite` are presumably honoured by a caller elsewhere — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadConfig {
    /// Source URL to fetch.
    pub url: String,
    /// Destination path for the downloaded file.
    pub output_path: PathBuf,
    /// Optional expected checksum of the file (presumably SHA-256 hex — verify with callers).
    pub checksum: Option<String>,
    /// Maximum number of retry attempts.
    pub max_retries: u32,
    /// Timeout to apply to the download.
    pub timeout: std::time::Duration,
    /// Whether an existing file at `output_path` may be replaced.
    pub overwrite: bool,
}
214
215impl DownloadConfig {
216 pub fn new(url: impl Into<String>, output_path: impl Into<PathBuf>) -> Self {
218 Self {
219 url: url.into(),
220 output_path: output_path.into(),
221 checksum: None,
222 max_retries: 3,
223 timeout: std::time::Duration::from_secs(300),
224 overwrite: false,
225 }
226 }
227
228 pub fn with_checksum(mut self, checksum: impl Into<String>) -> Self {
230 self.checksum = Some(checksum.into());
231 self
232 }
233
234 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
236 self.max_retries = max_retries;
237 self
238 }
239
240 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
242 self.timeout = timeout;
243 self
244 }
245
246 pub fn with_overwrite(mut self, overwrite: bool) -> Self {
248 self.overwrite = overwrite;
249 self
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_filename_from_url() {
        let downloader = Downloader::default();

        assert_eq!(
            downloader.extract_filename_from_url("https://example.com/file.zip"),
            "file.zip"
        );
        assert_eq!(
            downloader.extract_filename_from_url("https://example.com/file.zip?version=1.0"),
            "file.zip"
        );
        assert_eq!(
            downloader.extract_filename_from_url("https://example.com/a/b/archive.tar.gz"),
            "archive.tar.gz"
        );
        assert_eq!(
            downloader.extract_filename_from_url("https://example.com/"),
            "download"
        );
    }

    #[test]
    fn test_download_config() {
        let config = DownloadConfig::new("https://example.com/file.zip", "/tmp/file.zip")
            .with_checksum("abc123")
            .with_max_retries(5)
            .with_timeout(std::time::Duration::from_secs(30))
            .with_overwrite(true);

        assert_eq!(config.url, "https://example.com/file.zip");
        assert_eq!(config.output_path, PathBuf::from("/tmp/file.zip"));
        assert_eq!(config.checksum, Some("abc123".to_string()));
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.timeout, std::time::Duration::from_secs(30));
        assert!(config.overwrite);
    }

    /// `DownloadConfig::new` starts with no checksum, 3 retries, a 300 s timeout and no
    /// overwriting.
    #[test]
    fn test_download_config_defaults() {
        let config = DownloadConfig::new("https://example.com/file.zip", "/tmp/file.zip");

        assert_eq!(config.checksum, None);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.timeout, std::time::Duration::from_secs(300));
        assert!(!config.overwrite);
    }
}