Skip to main content

pro_core/installer/mod.rs

1//! Package installer with caching and parallel downloads
2//!
3//! Downloads wheels, verifies hashes, and unpacks to site-packages.
4
5use std::collections::HashMap;
6use std::fs::{self, File};
7use std::io::{self, Read, Write};
8use std::path::{Path, PathBuf};
9
10use futures::stream::{self, StreamExt};
11use sha2::{Digest, Sha256};
12use zip::ZipArchive;
13
14use crate::lockfile::LockedPackage;
15use crate::{Error, Result};
16
/// Upper bound on how many wheel downloads may be in flight at the same time.
const MAX_CONCURRENT_DOWNLOADS: usize = 8;
19
20/// Package installer
21pub struct Installer {
22    /// Cache directory for downloaded wheels
23    cache_dir: PathBuf,
24    /// HTTP client
25    client: reqwest::Client,
26}
27
/// Outcome of installing a single package.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InstallResult {
    /// Package name (the key from the lockfile package map)
    pub name: String,
    /// Whether the wheel was successfully extracted into site-packages
    pub installed: bool,
    /// Whether the wheel was downloaded during this run (false when it was
    /// reused from the local cache or could not be obtained at all)
    pub downloaded: bool,
    /// Human-readable error message when the package could not be installed
    pub error: Option<String>,
}
40
41impl Installer {
42    /// Create a new installer with the given cache directory
43    pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
44        Self {
45            cache_dir: cache_dir.into(),
46            client: reqwest::Client::builder()
47                .user_agent("Pro/0.1.0")
48                .build()
49                .expect("Failed to create HTTP client"),
50        }
51    }
52
53    /// Install packages from lockfile into site-packages
54    pub async fn install(
55        &self,
56        packages: &HashMap<String, LockedPackage>,
57        site_packages: &Path,
58    ) -> Result<Vec<InstallResult>> {
59        // Ensure cache directory exists
60        fs::create_dir_all(&self.cache_dir).map_err(Error::Io)?;
61        fs::create_dir_all(site_packages).map_err(Error::Io)?;
62
63        // Download packages in parallel
64        let download_tasks: Vec<_> = packages
65            .iter()
66            .map(|(name, pkg)| self.download_package(name, pkg))
67            .collect();
68
69        let download_results: Vec<_> = stream::iter(download_tasks)
70            .buffer_unordered(MAX_CONCURRENT_DOWNLOADS)
71            .collect()
72            .await;
73
74        // Install each downloaded package
75        let mut results = Vec::new();
76        for name in packages.keys() {
77            let download_result = download_results.iter().find(|(n, _, _)| n == name);
78
79            match download_result {
80                Some((_, Some(cached_path), downloaded)) => {
81                    match self.install_wheel(cached_path, site_packages) {
82                        Ok(()) => {
83                            results.push(InstallResult {
84                                name: name.clone(),
85                                installed: true,
86                                downloaded: *downloaded,
87                                error: None,
88                            });
89                        }
90                        Err(e) => {
91                            results.push(InstallResult {
92                                name: name.clone(),
93                                installed: false,
94                                downloaded: *downloaded,
95                                error: Some(e.to_string()),
96                            });
97                        }
98                    }
99                }
100                Some((_, None, _)) => {
101                    results.push(InstallResult {
102                        name: name.clone(),
103                        installed: false,
104                        downloaded: false,
105                        error: Some("No download URL available".into()),
106                    });
107                }
108                None => {
109                    results.push(InstallResult {
110                        name: name.clone(),
111                        installed: false,
112                        downloaded: false,
113                        error: Some("Download task not found".into()),
114                    });
115                }
116            }
117        }
118
119        Ok(results)
120    }
121
122    /// Download a package and return the cached path
123    async fn download_package(
124        &self,
125        name: &str,
126        pkg: &LockedPackage,
127    ) -> (String, Option<PathBuf>, bool) {
128        let url = match &pkg.url {
129            Some(u) if !u.is_empty() => u,
130            _ => return (name.to_string(), None, false),
131        };
132
133        // Determine cache path from URL
134        let filename = url.rsplit('/').next().unwrap_or("package.whl");
135        let cached_path = self.cache_dir.join(filename);
136
137        // Check if already cached with valid hash
138        if cached_path.exists() {
139            if let Some(expected_hash) = &pkg.hash {
140                if let Ok(actual_hash) = compute_file_hash(&cached_path) {
141                    if verify_hash(&actual_hash, expected_hash) {
142                        tracing::debug!("Using cached: {}", filename);
143                        return (name.to_string(), Some(cached_path), false);
144                    }
145                }
146            } else {
147                // No hash to verify, assume cached is good
148                tracing::debug!("Using cached (no hash): {}", filename);
149                return (name.to_string(), Some(cached_path), false);
150            }
151        }
152
153        // Download the file
154        tracing::info!("Downloading {} from {}", name, url);
155        match self.download_file(url, &cached_path).await {
156            Ok(()) => {
157                // Verify hash if provided
158                if let Some(expected_hash) = &pkg.hash {
159                    match compute_file_hash(&cached_path) {
160                        Ok(actual_hash) => {
161                            if !verify_hash(&actual_hash, expected_hash) {
162                                tracing::error!(
163                                    "Hash mismatch for {}: expected {}, got sha256:{}",
164                                    name,
165                                    expected_hash,
166                                    actual_hash
167                                );
168                                let _ = fs::remove_file(&cached_path);
169                                return (name.to_string(), None, false);
170                            }
171                            tracing::debug!("Hash verified for {}", name);
172                        }
173                        Err(e) => {
174                            tracing::error!("Failed to compute hash for {}: {}", name, e);
175                            return (name.to_string(), None, false);
176                        }
177                    }
178                }
179                (name.to_string(), Some(cached_path), true)
180            }
181            Err(e) => {
182                tracing::error!("Failed to download {}: {}", name, e);
183                (name.to_string(), None, false)
184            }
185        }
186    }
187
188    /// Download a file to the given path
189    async fn download_file(&self, url: &str, dest: &Path) -> Result<()> {
190        let response = self.client.get(url).send().await.map_err(Error::Network)?;
191
192        if !response.status().is_success() {
193            return Err(Error::Index(format!(
194                "Failed to download {}: HTTP {}",
195                url,
196                response.status()
197            )));
198        }
199
200        let bytes = response.bytes().await.map_err(Error::Network)?;
201
202        let mut file = File::create(dest).map_err(Error::Io)?;
203        file.write_all(&bytes).map_err(Error::Io)?;
204
205        Ok(())
206    }
207
208    /// Install a wheel file to site-packages
209    fn install_wheel(&self, wheel_path: &Path, site_packages: &Path) -> Result<()> {
210        tracing::debug!("Installing wheel: {:?}", wheel_path);
211
212        let file = File::open(wheel_path).map_err(Error::Io)?;
213        let mut archive = ZipArchive::new(file)
214            .map_err(|e| Error::BuildError(format!("Invalid wheel: {}", e)))?;
215
216        for i in 0..archive.len() {
217            let mut entry = archive
218                .by_index(i)
219                .map_err(|e| Error::BuildError(format!("Failed to read wheel entry: {}", e)))?;
220
221            let entry_path = entry
222                .enclosed_name()
223                .ok_or_else(|| Error::BuildError("Invalid entry path in wheel".into()))?;
224
225            let dest_path = site_packages.join(&entry_path);
226
227            if entry.is_dir() {
228                fs::create_dir_all(&dest_path).map_err(Error::Io)?;
229            } else {
230                // Create parent directories
231                if let Some(parent) = dest_path.parent() {
232                    fs::create_dir_all(parent).map_err(Error::Io)?;
233                }
234
235                // Extract file
236                let mut outfile = File::create(&dest_path).map_err(Error::Io)?;
237                io::copy(&mut entry, &mut outfile).map_err(Error::Io)?;
238
239                // Set executable permissions for scripts
240                #[cfg(unix)]
241                {
242                    use std::os::unix::fs::PermissionsExt;
243                    if entry_path.starts_with("..")
244                        || entry_path.to_string_lossy().contains("/bin/")
245                    {
246                        let mut perms = fs::metadata(&dest_path).map_err(Error::Io)?.permissions();
247                        perms.set_mode(0o755);
248                        fs::set_permissions(&dest_path, perms).map_err(Error::Io)?;
249                    }
250                }
251            }
252        }
253
254        Ok(())
255    }
256
257    /// Get the cache directory
258    pub fn cache_dir(&self) -> &Path {
259        &self.cache_dir
260    }
261
262    /// Clear the cache
263    pub fn clear_cache(&self) -> Result<()> {
264        if self.cache_dir.exists() {
265            fs::remove_dir_all(&self.cache_dir).map_err(Error::Io)?;
266        }
267        Ok(())
268    }
269}
270
271/// Compute SHA256 hash of a file
272fn compute_file_hash(path: &Path) -> Result<String> {
273    let mut file = File::open(path).map_err(Error::Io)?;
274    let mut hasher = Sha256::new();
275    let mut buffer = [0u8; 8192];
276
277    loop {
278        let bytes_read = file.read(&mut buffer).map_err(Error::Io)?;
279        if bytes_read == 0 {
280            break;
281        }
282        hasher.update(&buffer[..bytes_read]);
283    }
284
285    Ok(hex::encode(hasher.finalize()))
286}
287
/// Check whether a hex-encoded SHA-256 digest matches an expected value.
///
/// `expected` may be a bare hex digest, or it may carry a `sha256:` or `sha256=`
/// prefix (the latter is the `#sha256=...` URL-fragment form from PEP 503 indexes).
/// Surrounding whitespace in `expected` is ignored. The hex comparison is ASCII
/// case-insensitive.
fn verify_hash(actual: &str, expected: &str) -> bool {
    let expected = expected.trim();
    let expected_hex = expected
        .strip_prefix("sha256:")
        .or_else(|| expected.strip_prefix("sha256="))
        .unwrap_or(expected);
    actual.eq_ignore_ascii_case(expected_hex)
}
296
/// Get the default directory for cached wheels.
///
/// * Unix: `$XDG_CACHE_HOME/rx/wheels`, otherwise `$HOME/.cache/rx/wheels`.
/// * Windows: `%LOCALAPPDATA%\rx\cache\wheels`.
/// * Fallback when none of those is usable: `<temp dir>/rx-cache/wheels`, where
///   `<temp dir>` is [`std::env::temp_dir`].
///
/// The XDG Base Directory spec says an empty or relative `XDG_CACHE_HOME` must be
/// ignored, so it is. Empty `HOME`/`LOCALAPPDATA` values are ignored as well.
/// Non-UTF-8 values are accepted.
pub fn default_cache_dir() -> PathBuf {
    #[cfg(unix)]
    {
        let xdg_cache = std::env::var_os("XDG_CACHE_HOME")
            .map(PathBuf::from)
            .filter(|dir| dir.is_absolute());
        if let Some(xdg_cache) = xdg_cache {
            return xdg_cache.join("rx").join("wheels");
        }
        if let Some(home) = std::env::var_os("HOME").filter(|home| !home.is_empty()) {
            return PathBuf::from(home).join(".cache").join("rx").join("wheels");
        }
    }

    #[cfg(windows)]
    {
        if let Some(local_app_data) =
            std::env::var_os("LOCALAPPDATA").filter(|dir| !dir.is_empty())
        {
            return PathBuf::from(local_app_data)
                .join("rx")
                .join("cache")
                .join("wheels");
        }
    }

    // Portable fallback; a hard-coded `/tmp` path does not exist on Windows.
    std::env::temp_dir().join("rx-cache").join("wheels")
}
324
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_compute_file_hash() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("test.txt");
        fs::write(&target, b"hello world").unwrap();

        // Well-known SHA-256 digest of the ASCII string "hello world".
        let expected = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
        assert_eq!(compute_file_hash(&target).unwrap(), expected);
    }

    #[test]
    fn test_verify_hash() {
        let digest = "abc123def456";
        // A prefixed form and an upper-cased form of the same digest must both match.
        for matching in ["sha256:abc123def456", "ABC123DEF456"] {
            assert!(verify_hash(digest, matching), "expected match for {}", matching);
        }
        assert!(!verify_hash(digest, "sha256:different"));
    }

    #[test]
    fn test_installer_new() {
        let cache = Path::new("/tmp/test-cache");
        let installer = Installer::new(cache);
        assert_eq!(installer.cache_dir(), cache);
    }

    #[test]
    fn test_default_cache_dir() {
        let rendered = default_cache_dir().to_string_lossy().into_owned();
        assert!(rendered.contains("rx"));
    }
}