1use std::collections::HashMap;
6use std::fs::{self, File};
7use std::io::{self, Read, Write};
8use std::path::{Path, PathBuf};
9
10use futures::stream::{self, StreamExt};
11use sha2::{Digest, Sha256};
12use zip::ZipArchive;
13
14use crate::lockfile::LockedPackage;
15use crate::{Error, Result};
16
/// Upper bound on how many package downloads `Installer::install` keeps in
/// flight at once (passed to `buffer_unordered`).
const MAX_CONCURRENT_DOWNLOADS: usize = 8;
19
20pub struct Installer {
22 cache_dir: PathBuf,
24 client: reqwest::Client,
26}
27
/// Per-package outcome reported by `Installer::install`.
#[derive(Debug)]
pub struct InstallResult {
    /// Package name, as keyed in the lockfile map passed to `install`.
    pub name: String,
    /// `true` once the wheel's contents were extracted into `site-packages`.
    pub installed: bool,
    /// `true` if the wheel was fetched over the network during this run;
    /// `false` when a cached copy was reused or no wheel was obtained.
    pub downloaded: bool,
    /// Human-readable reason the package was not installed, if any.
    pub error: Option<String>,
}
40
41impl Installer {
42 pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
44 Self {
45 cache_dir: cache_dir.into(),
46 client: reqwest::Client::builder()
47 .user_agent("Pro/0.1.0")
48 .build()
49 .expect("Failed to create HTTP client"),
50 }
51 }
52
53 pub async fn install(
55 &self,
56 packages: &HashMap<String, LockedPackage>,
57 site_packages: &Path,
58 ) -> Result<Vec<InstallResult>> {
59 fs::create_dir_all(&self.cache_dir).map_err(Error::Io)?;
61 fs::create_dir_all(site_packages).map_err(Error::Io)?;
62
63 let download_tasks: Vec<_> = packages
65 .iter()
66 .map(|(name, pkg)| self.download_package(name, pkg))
67 .collect();
68
69 let download_results: Vec<_> = stream::iter(download_tasks)
70 .buffer_unordered(MAX_CONCURRENT_DOWNLOADS)
71 .collect()
72 .await;
73
74 let mut results = Vec::new();
76 for name in packages.keys() {
77 let download_result = download_results.iter().find(|(n, _, _)| n == name);
78
79 match download_result {
80 Some((_, Some(cached_path), downloaded)) => {
81 match self.install_wheel(cached_path, site_packages) {
82 Ok(()) => {
83 results.push(InstallResult {
84 name: name.clone(),
85 installed: true,
86 downloaded: *downloaded,
87 error: None,
88 });
89 }
90 Err(e) => {
91 results.push(InstallResult {
92 name: name.clone(),
93 installed: false,
94 downloaded: *downloaded,
95 error: Some(e.to_string()),
96 });
97 }
98 }
99 }
100 Some((_, None, _)) => {
101 results.push(InstallResult {
102 name: name.clone(),
103 installed: false,
104 downloaded: false,
105 error: Some("No download URL available".into()),
106 });
107 }
108 None => {
109 results.push(InstallResult {
110 name: name.clone(),
111 installed: false,
112 downloaded: false,
113 error: Some("Download task not found".into()),
114 });
115 }
116 }
117 }
118
119 Ok(results)
120 }
121
122 async fn download_package(
124 &self,
125 name: &str,
126 pkg: &LockedPackage,
127 ) -> (String, Option<PathBuf>, bool) {
128 let url = match &pkg.url {
129 Some(u) if !u.is_empty() => u,
130 _ => return (name.to_string(), None, false),
131 };
132
133 let filename = url.rsplit('/').next().unwrap_or("package.whl");
135 let cached_path = self.cache_dir.join(filename);
136
137 if cached_path.exists() {
139 if let Some(expected_hash) = &pkg.hash {
140 if let Ok(actual_hash) = compute_file_hash(&cached_path) {
141 if verify_hash(&actual_hash, expected_hash) {
142 tracing::debug!("Using cached: {}", filename);
143 return (name.to_string(), Some(cached_path), false);
144 }
145 }
146 } else {
147 tracing::debug!("Using cached (no hash): {}", filename);
149 return (name.to_string(), Some(cached_path), false);
150 }
151 }
152
153 tracing::info!("Downloading {} from {}", name, url);
155 match self.download_file(url, &cached_path).await {
156 Ok(()) => {
157 if let Some(expected_hash) = &pkg.hash {
159 match compute_file_hash(&cached_path) {
160 Ok(actual_hash) => {
161 if !verify_hash(&actual_hash, expected_hash) {
162 tracing::error!(
163 "Hash mismatch for {}: expected {}, got sha256:{}",
164 name,
165 expected_hash,
166 actual_hash
167 );
168 let _ = fs::remove_file(&cached_path);
169 return (name.to_string(), None, false);
170 }
171 tracing::debug!("Hash verified for {}", name);
172 }
173 Err(e) => {
174 tracing::error!("Failed to compute hash for {}: {}", name, e);
175 return (name.to_string(), None, false);
176 }
177 }
178 }
179 (name.to_string(), Some(cached_path), true)
180 }
181 Err(e) => {
182 tracing::error!("Failed to download {}: {}", name, e);
183 (name.to_string(), None, false)
184 }
185 }
186 }
187
188 async fn download_file(&self, url: &str, dest: &Path) -> Result<()> {
190 let response = self.client.get(url).send().await.map_err(Error::Network)?;
191
192 if !response.status().is_success() {
193 return Err(Error::Index(format!(
194 "Failed to download {}: HTTP {}",
195 url,
196 response.status()
197 )));
198 }
199
200 let bytes = response.bytes().await.map_err(Error::Network)?;
201
202 let mut file = File::create(dest).map_err(Error::Io)?;
203 file.write_all(&bytes).map_err(Error::Io)?;
204
205 Ok(())
206 }
207
208 fn install_wheel(&self, wheel_path: &Path, site_packages: &Path) -> Result<()> {
210 tracing::debug!("Installing wheel: {:?}", wheel_path);
211
212 let file = File::open(wheel_path).map_err(Error::Io)?;
213 let mut archive = ZipArchive::new(file)
214 .map_err(|e| Error::BuildError(format!("Invalid wheel: {}", e)))?;
215
216 for i in 0..archive.len() {
217 let mut entry = archive
218 .by_index(i)
219 .map_err(|e| Error::BuildError(format!("Failed to read wheel entry: {}", e)))?;
220
221 let entry_path = entry
222 .enclosed_name()
223 .ok_or_else(|| Error::BuildError("Invalid entry path in wheel".into()))?;
224
225 let dest_path = site_packages.join(&entry_path);
226
227 if entry.is_dir() {
228 fs::create_dir_all(&dest_path).map_err(Error::Io)?;
229 } else {
230 if let Some(parent) = dest_path.parent() {
232 fs::create_dir_all(parent).map_err(Error::Io)?;
233 }
234
235 let mut outfile = File::create(&dest_path).map_err(Error::Io)?;
237 io::copy(&mut entry, &mut outfile).map_err(Error::Io)?;
238
239 #[cfg(unix)]
241 {
242 use std::os::unix::fs::PermissionsExt;
243 if entry_path.starts_with("..")
244 || entry_path.to_string_lossy().contains("/bin/")
245 {
246 let mut perms = fs::metadata(&dest_path).map_err(Error::Io)?.permissions();
247 perms.set_mode(0o755);
248 fs::set_permissions(&dest_path, perms).map_err(Error::Io)?;
249 }
250 }
251 }
252 }
253
254 Ok(())
255 }
256
257 pub fn cache_dir(&self) -> &Path {
259 &self.cache_dir
260 }
261
262 pub fn clear_cache(&self) -> Result<()> {
264 if self.cache_dir.exists() {
265 fs::remove_dir_all(&self.cache_dir).map_err(Error::Io)?;
266 }
267 Ok(())
268 }
269}
270
271fn compute_file_hash(path: &Path) -> Result<String> {
273 let mut file = File::open(path).map_err(Error::Io)?;
274 let mut hasher = Sha256::new();
275 let mut buffer = [0u8; 8192];
276
277 loop {
278 let bytes_read = file.read(&mut buffer).map_err(Error::Io)?;
279 if bytes_read == 0 {
280 break;
281 }
282 hasher.update(&buffer[..bytes_read]);
283 }
284
285 Ok(hex::encode(hasher.finalize()))
286}
287
/// Reports whether the hex digest `actual` matches the pinned hash `expected`.
///
/// `expected` may carry a `sha256:` algorithm prefix in any letter case
/// (`sha256:`, `SHA256:`, ...). The hex digits are compared ASCII
/// case-insensitively. A hash pinned with any other algorithm prefix never
/// matches. The comparison allocates nothing.
fn verify_hash(actual: &str, expected: &str) -> bool {
    const PREFIX: &str = "sha256:";
    // `get` returns `None` rather than panicking when byte 7 is not a char
    // boundary, so slicing at `PREFIX.len()` below is safe.
    let expected_digest = match expected.get(..PREFIX.len()) {
        Some(head) if head.eq_ignore_ascii_case(PREFIX) => &expected[PREFIX.len()..],
        _ => expected,
    };
    actual.eq_ignore_ascii_case(expected_digest)
}
296
/// Returns the platform-default directory for cached wheels.
///
/// - Unix:
///   - `$XDG_CACHE_HOME/rx/wheels` if set,
///   - else `$HOME/.cache/rx/wheels`.
///   - Per the XDG Base Directory spec, unset, empty or relative values are
///     ignored.
/// - Windows: `%LOCALAPPDATA%\rx\cache\wheels`.
/// - Otherwise, or if none of the above is usable:
///   `<std::env::temp_dir()>/rx-cache/wheels`. On Unix that is
///   `/tmp/rx-cache/wheels` unless `TMPDIR` says otherwise.
pub fn default_cache_dir() -> PathBuf {
    /// Reads environment variable `key` as a path, treating unset, empty,
    /// or relative values as absent.
    #[cfg_attr(not(any(unix, windows)), allow(dead_code))]
    fn absolute_env_path(key: &str) -> Option<PathBuf> {
        std::env::var_os(key)
            .map(PathBuf::from)
            .filter(|path| path.is_absolute())
    }

    #[cfg(unix)]
    {
        if let Some(xdg_cache) = absolute_env_path("XDG_CACHE_HOME") {
            return xdg_cache.join("rx").join("wheels");
        }
        if let Some(home) = absolute_env_path("HOME") {
            return home.join(".cache").join("rx").join("wheels");
        }
    }

    #[cfg(windows)]
    {
        if let Some(local_app_data) = absolute_env_path("LOCALAPPDATA") {
            return local_app_data.join("rx").join("cache").join("wheels");
        }
    }

    std::env::temp_dir().join("rx-cache").join("wheels")
}
324
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Hashing a small known file yields its well-known SHA-256 digest.
    #[test]
    fn test_compute_file_hash() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("test.txt");
        std::fs::write(&target, b"hello world").unwrap();

        let digest = compute_file_hash(&target).unwrap();

        // SHA-256 of the ASCII bytes "hello world".
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert_eq!(digest, expected);
    }

    /// Prefixed and upper-case forms are accepted; a different digest is not.
    #[test]
    fn test_verify_hash() {
        let actual = "abc123def456";
        for accepted in &["sha256:abc123def456", "ABC123DEF456"] {
            assert!(verify_hash(actual, accepted), "should accept {}", accepted);
        }
        assert!(!verify_hash(actual, "sha256:different"));
    }

    /// The constructor stores the cache directory it is given.
    #[test]
    fn test_installer_new() {
        let expected = Path::new("/tmp/test-cache");
        let installer = Installer::new("/tmp/test-cache");
        assert_eq!(installer.cache_dir(), expected);
    }

    /// Whatever platform branch is taken, the path is namespaced under "rx".
    #[test]
    fn test_default_cache_dir() {
        let rendered = default_cache_dir().to_string_lossy().into_owned();
        assert!(rendered.contains("rx"));
    }
}