Skip to main content

pro_core/resolver/
mod.rs

1//! Dependency resolver using pubgrub algorithm
2
3mod package;
4mod provider;
5
6pub use package::Package;
7pub use provider::{PyPIProvider, ROOT_PACKAGE};
8
9use std::sync::Arc;
10
11use pubgrub::error::PubGrubError;
12use pubgrub::report::{DefaultStringReporter, Reporter};
13use pubgrub::solver::resolve;
14use tracing::{debug, info, instrument, warn};
15
16use crate::index::{FileInfo, PyPIClient};
17use crate::pep::{Requirement, Version};
18use crate::{Error, Result};
19
20/// Dependency resolver for Python packages
21pub struct Resolver {
22    /// PyPI client for fetching metadata
23    client: Arc<PyPIClient>,
24}
25
26impl Resolver {
27    /// Create a new resolver with default PyPI client
28    pub fn new() -> Self {
29        Self {
30            client: Arc::new(PyPIClient::new()),
31        }
32    }
33
34    /// Create a new resolver with a custom client
35    pub fn with_client(client: PyPIClient) -> Self {
36        Self {
37            client: Arc::new(client),
38        }
39    }
40
41    /// Resolve dependencies for a list of requirements
42    #[instrument(skip(self, requirements))]
43    pub async fn resolve(&self, requirements: &[Requirement]) -> Result<Resolution> {
44        if requirements.is_empty() {
45            return Ok(Resolution { packages: vec![] });
46        }
47
48        info!("resolving {} requirements", requirements.len());
49
50        // Phase 1: Pre-fetch metadata for all direct requirements
51        debug!("fetching metadata for direct dependencies");
52        let mut provider = PyPIProvider::build(&self.client, requirements).await?;
53
54        // Phase 2: Pre-crawl dependency graph to discover all needed packages
55        // We need all metadata BEFORE running pubgrub
56        let mut iteration = 0;
57        const MAX_ITERATIONS: usize = 20;
58
59        loop {
60            iteration += 1;
61            if iteration > MAX_ITERATIONS {
62                return Err(Error::Resolution(
63                    "too many iterations discovering dependencies".to_string(),
64                ));
65            }
66
67            // Discover all packages referenced in dependencies
68            let missing = provider.discover_all_packages();
69            if missing.is_empty() {
70                break;
71            }
72
73            debug!(
74                "iteration {}: fetching {} missing packages",
75                iteration,
76                missing.len()
77            );
78
79            // Fetch missing metadata
80            let results = self.client.get_packages_concurrent(&missing).await;
81            for (name, result) in results {
82                match result {
83                    Ok(meta) => provider.add_metadata(name, meta),
84                    Err(Error::PackageNotFound { .. }) => {
85                        warn!("dependency {} not found on PyPI, skipping", name);
86                    }
87                    Err(e) => return Err(e),
88                }
89            }
90        }
91
92        debug!("all dependencies discovered, running pubgrub solver");
93
94        // Phase 3: Run pubgrub solver with complete metadata
95        let root = Package::new(ROOT_PACKAGE);
96        let solution = match resolve(&provider, root.clone(), Version::new(vec![1, 0, 0])) {
97            Ok(sol) => sol,
98            Err(PubGrubError::NoSolution(mut tree)) => {
99                tree.collapse_no_versions();
100                let msg = DefaultStringReporter::report(&tree);
101                return Err(Error::Resolution(msg));
102            }
103            Err(PubGrubError::ErrorChoosingPackageVersion(e)) => {
104                return Err(Error::Resolution(format!("error choosing version: {}", e)));
105            }
106            Err(PubGrubError::ErrorRetrievingDependencies {
107                package,
108                version,
109                source,
110            }) => {
111                return Err(Error::Resolution(format!(
112                    "error getting dependencies for {} {}: {}",
113                    package, version, source
114                )));
115            }
116            Err(PubGrubError::SelfDependency { package, version }) => {
117                return Err(Error::Resolution(format!(
118                    "package {} {} depends on itself",
119                    package, version
120                )));
121            }
122            Err(PubGrubError::DependencyOnTheEmptySet {
123                package,
124                version,
125                dependent,
126            }) => {
127                return Err(Error::Resolution(format!(
128                    "package {} {} has impossible dependency on {}",
129                    package, version, dependent
130                )));
131            }
132            Err(PubGrubError::Failure(msg)) => {
133                return Err(Error::Resolution(msg));
134            }
135            Err(PubGrubError::ErrorInShouldCancel(e)) => {
136                return Err(Error::Resolution(format!("resolution cancelled: {}", e)));
137            }
138        };
139
140        self.build_resolution(&solution).await
141    }
142
143    /// Build the final resolution from the pubgrub solution
144    async fn build_resolution(
145        &self,
146        solution: &pubgrub::type_aliases::SelectedDependencies<Package, Version>,
147    ) -> Result<Resolution> {
148        let mut packages = Vec::new();
149
150        for (package, version) in solution {
151            // Skip the root package
152            if package.name == ROOT_PACKAGE {
153                continue;
154            }
155
156            // Fetch metadata for this specific version to get file info
157            let metadata = self.client.get_package(&package.name).await?;
158            let version_str = version.to_string();
159
160            // Find the best file for this version
161            let release_files = metadata
162                .releases
163                .get(&version_str)
164                .cloned()
165                .unwrap_or_default();
166            let file = Self::select_best_file(&release_files);
167
168            let (url, hash) = match file {
169                Some(f) => {
170                    let hash = f
171                        .best_hash()
172                        .map(|(algo, h)| format!("{}:{}", algo, h))
173                        .unwrap_or_default();
174                    (f.url.clone(), hash)
175                }
176                None => {
177                    warn!("no suitable file found for {}=={}", package.name, version);
178                    (String::new(), String::new())
179                }
180            };
181
182            // Extract dependencies from metadata
183            let dependencies = Self::extract_dependencies(&metadata, &version_str);
184
185            // Build platform-specific files list
186            let platform_files = Self::build_platform_files(&release_files);
187
188            packages.push(ResolvedPackage {
189                name: package.name.clone(),
190                version: version_str,
191                url,
192                hash,
193                dependencies,
194                markers: None, // TODO: Extract from requires_dist
195                files: platform_files,
196            });
197        }
198
199        // Sort packages alphabetically for consistent output
200        packages.sort_by(|a, b| a.name.cmp(&b.name));
201
202        info!("resolved {} packages", packages.len());
203        Ok(Resolution { packages })
204    }
205
206    /// Extract dependency names from package metadata
207    fn extract_dependencies(
208        metadata: &crate::index::PackageMetadata,
209        _version: &str,
210    ) -> Vec<String> {
211        // Get requires_dist from package info
212        if let Some(requires_dist) = &metadata.info.requires_dist {
213            return requires_dist
214                .iter()
215                .filter_map(|req| {
216                    // Parse the requirement to get the package name
217                    // Format: "package-name (>=1.0)" or "package-name; extra == 'dev'"
218                    let name = req.split([' ', ';', '[', '(']).next().unwrap_or(req).trim();
219                    if name.is_empty() {
220                        None
221                    } else {
222                        Some(Package::new(name).name)
223                    }
224                })
225                .collect();
226        }
227        Vec::new()
228    }
229
230    /// Build platform-specific file entries
231    fn build_platform_files(files: &[FileInfo]) -> Vec<ResolvedFile> {
232        files
233            .iter()
234            .filter(|f| !f.yanked && f.is_wheel())
235            .filter_map(|f| {
236                let hash = f
237                    .best_hash()
238                    .map(|(algo, h)| format!("{}:{}", algo, h))
239                    .unwrap_or_default();
240
241                if hash.is_empty() {
242                    return None;
243                }
244
245                let tags = f.parse_wheel_tags();
246                let (markers, python, tag_str) = match &tags {
247                    Some(t) => {
248                        let markers = Self::tags_to_markers(t);
249                        let python = if t.python.contains("py3") {
250                            Some(">=3.0".to_string())
251                        } else if t.python.contains("py2") {
252                            Some("<3.0".to_string())
253                        } else {
254                            None
255                        };
256                        let tag_str = format!("{}-{}-{}", t.python, t.abi, t.platform);
257                        (markers, python, Some(tag_str))
258                    }
259                    None => (None, None, None),
260                };
261
262                Some(ResolvedFile {
263                    url: f.url.clone(),
264                    hash,
265                    markers,
266                    python,
267                    tags: tag_str,
268                })
269            })
270            .collect()
271    }
272
273    /// Convert wheel tags to platform markers
274    fn tags_to_markers(tags: &crate::index::WheelTags) -> Option<String> {
275        if tags.is_universal() {
276            return None; // Universal wheel, no markers needed
277        }
278
279        let mut markers = Vec::new();
280
281        // Platform markers
282        if tags.platform.contains("win") {
283            markers.push("sys_platform == 'win32'".to_string());
284        } else if tags.platform.contains("macosx") || tags.platform.contains("darwin") {
285            markers.push("sys_platform == 'darwin'".to_string());
286        } else if tags.platform.contains("linux") {
287            markers.push("sys_platform == 'linux'".to_string());
288        }
289
290        if markers.is_empty() {
291            None
292        } else {
293            Some(markers.join(" and "))
294        }
295    }
296
297    /// Select the best file from available files
298    /// Prefers wheels over sdists, and universal wheels over platform-specific
299    fn select_best_file(files: &[FileInfo]) -> Option<&FileInfo> {
300        // Filter out yanked files
301        let available: Vec<_> = files.iter().filter(|f| !f.yanked).collect();
302
303        if available.is_empty() {
304            return None;
305        }
306
307        // First, try to find a universal wheel (py3-none-any)
308        for file in &available {
309            if file.is_wheel() {
310                if let Some(tags) = file.parse_wheel_tags() {
311                    if tags.is_universal() && tags.python.contains("py3") {
312                        return Some(file);
313                    }
314                }
315            }
316        }
317
318        // Next, try any wheel
319        for file in &available {
320            if file.is_wheel() {
321                return Some(file);
322            }
323        }
324
325        // Finally, fall back to sdist
326        for file in &available {
327            if file.is_sdist() {
328                return Some(file);
329            }
330        }
331
332        // Return first available as last resort
333        available.first().copied()
334    }
335}
336
337impl Default for Resolver {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
343/// Result of dependency resolution
344#[derive(Debug, Clone)]
345pub struct Resolution {
346    /// Resolved packages with their versions
347    pub packages: Vec<ResolvedPackage>,
348}
349
350impl Resolution {
351    /// Get a package by name
352    pub fn get(&self, name: &str) -> Option<&ResolvedPackage> {
353        let normalized = Package::new(name).name;
354        self.packages
355            .iter()
356            .find(|p| Package::new(&p.name).name == normalized)
357    }
358
359    /// Check if a package is in the resolution
360    pub fn contains(&self, name: &str) -> bool {
361        self.get(name).is_some()
362    }
363
364    /// Number of resolved packages
365    pub fn len(&self) -> usize {
366        self.packages.len()
367    }
368
369    /// Check if resolution is empty
370    pub fn is_empty(&self) -> bool {
371        self.packages.is_empty()
372    }
373}
374
/// One package of a resolution, pinned to a single version.
#[derive(Debug, Clone)]
pub struct ResolvedPackage {
    /// Normalized package name
    pub name: String,
    /// Version that was selected
    pub version: String,
    /// URL of the default (universal) download
    pub url: String,
    /// Hash of the default download, written as "algorithm:hash"
    pub hash: String,
    /// Normalized names of the direct dependencies
    pub dependencies: Vec<String>,
    /// PEP 508 platform markers, if any
    pub markers: Option<String>,
    /// Per-platform download files
    pub files: Vec<ResolvedFile>,
}

/// A single downloadable file plus the platform information attached to it.
#[derive(Debug, Clone)]
pub struct ResolvedFile {
    /// Where to download the file from
    pub url: String,
    /// Hash of the file
    pub hash: String,
    /// Platform markers, if any
    pub markers: Option<String>,
    /// Constraint on the Python version, if any
    pub python: Option<String>,
    /// Wheel tags, if any
    pub tags: Option<String>,
}

impl ResolvedPackage {
    /// Split `hash` at its first `:` into `(algorithm, value)`.
    ///
    /// Returns `None` when the hash contains no `:` at all.
    pub fn parse_hash(&self) -> Option<(&str, &str)> {
        let mut pieces = self.hash.splitn(2, ':');
        match (pieces.next(), pieces.next()) {
            (Some(algorithm), Some(value)) => Some((algorithm, value)),
            _ => None,
        }
    }
}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal package entry with no URL, dependencies, markers or files.
    fn package_fixture(name: &str, version: &str, hash: &str) -> ResolvedPackage {
        ResolvedPackage {
            name: name.to_owned(),
            version: version.to_owned(),
            url: String::new(),
            hash: hash.to_owned(),
            dependencies: Vec::new(),
            markers: None,
            files: Vec::new(),
        }
    }

    /// Lookups succeed regardless of the casing of the queried name and
    /// fail for packages that are not part of the resolution.
    #[test]
    fn test_resolution_get() {
        let resolution = Resolution {
            packages: vec![package_fixture("requests", "2.28.0", "sha256:abc")],
        };

        assert!(resolution.contains("requests"));
        assert!(resolution.contains("Requests")); // Normalized
        assert!(!resolution.contains("urllib3"));
    }

    /// A well-formed hash splits into its algorithm and digest.
    #[test]
    fn test_parse_hash() {
        let pkg = package_fixture("test", "1.0.0", "sha256:abc123");

        let (algo, hash) = pkg.parse_hash().unwrap();
        assert_eq!(algo, "sha256");
        assert_eq!(hash, "abc123");
    }

    /// End-to-end resolution against the real index (ignored by default).
    #[tokio::test]
    #[ignore = "requires network"]
    async fn test_resolve_requests() {
        let requirements = vec![Requirement::parse("requests>=2.28.0").unwrap()];
        let resolver = Resolver::new();

        let resolution = resolver.resolve(&requirements).await.unwrap();

        // requests should pull in urllib3, charset-normalizer, idna, certifi
        for expected in ["requests", "urllib3", "certifi"].iter() {
            assert!(resolution.contains(expected));
        }
    }
}