Skip to main content

pro_core/resolver/
provider.rs

1//! pubgrub DependencyProvider implementation for PyPI packages
2
3use std::borrow::Borrow;
4use std::cell::RefCell;
5use std::collections::{HashMap, HashSet};
6use std::error::Error as StdError;
7
8use pubgrub::range::Range;
9use pubgrub::solver::{Dependencies, DependencyProvider};
10
11use crate::index::{PackageMetadata, PyPIClient};
12use crate::pep::{Requirement, Version, VersionSpecifiers};
13use crate::resolver::Package;
14use crate::Error;
15
/// Name of the synthetic package that stands in for the project being resolved.
///
/// The value is already in the form `Package::normalize()` produces, so
/// wrapping it in a `Package` leaves the name unchanged and it can be compared
/// against `Package::name` directly.
pub const ROOT_PACKAGE: &str = "root-project";
19
20/// PyPI-based dependency provider for pubgrub
21pub struct PyPIProvider {
22    /// Cached package metadata (package name -> metadata)
23    metadata_cache: HashMap<String, PackageMetadata>,
24    /// Cached parsed versions (package name -> sorted versions, newest first)
25    versions_cache: HashMap<String, Vec<Version>>,
26    /// Packages we've seen during resolution (interior mutability for get_dependencies)
27    seen_packages: RefCell<HashSet<String>>,
28    /// Root package dependencies
29    root_deps: Vec<(Package, Range<Version>)>,
30}
31
32impl PyPIProvider {
33    /// Create a new provider with pre-fetched metadata
34    pub fn new(
35        metadata: HashMap<String, PackageMetadata>,
36        root_deps: Vec<(Package, Range<Version>)>,
37    ) -> Self {
38        let mut versions_cache = HashMap::new();
39
40        // Pre-parse versions for all packages
41        for (name, meta) in &metadata {
42            let mut versions: Vec<Version> = meta
43                .releases
44                .iter()
45                .filter(|(_, files)| files.iter().any(|f| !f.yanked && !files.is_empty()))
46                .filter_map(|(v, _)| Version::parse(v).ok())
47                .collect();
48
49            // Sort newest first (pubgrub expects this for choose_version)
50            versions.sort_by(|a, b| b.cmp(a));
51            versions_cache.insert(name.clone(), versions);
52        }
53
54        // Initialize seen packages with root deps
55        let mut seen = HashSet::new();
56        for (pkg, _) in &root_deps {
57            seen.insert(pkg.name.clone());
58        }
59
60        Self {
61            metadata_cache: metadata,
62            versions_cache,
63            seen_packages: RefCell::new(seen),
64            root_deps,
65        }
66    }
67
68    /// Build the provider by pre-fetching all required metadata
69    pub async fn build(client: &PyPIClient, requirements: &[Requirement]) -> Result<Self, Error> {
70        // Collect all package names we need to fetch
71        let names: Vec<String> = requirements.iter().map(|r| r.name.clone()).collect();
72
73        // Fetch all metadata concurrently
74        let results = client.get_packages_concurrent(&names).await;
75
76        // Collect successful fetches and errors
77        let mut metadata = HashMap::new();
78        for (name, result) in results {
79            match result {
80                Ok(meta) => {
81                    metadata.insert(name, meta);
82                }
83                Err(e) => {
84                    return Err(e);
85                }
86            }
87        }
88
89        // Parse root dependencies
90        let root_deps = requirements
91            .iter()
92            .map(|r| {
93                let pkg = Package::new(&r.name);
94                let range = if let Some(ref spec) = r.specifier {
95                    VersionSpecifiers::parse(spec)
96                        .map(|s| s.to_pubgrub_range())
97                        .unwrap_or_else(|_| Range::any())
98                } else {
99                    Range::any()
100                };
101                (pkg, range)
102            })
103            .collect();
104
105        Ok(Self::new(metadata, root_deps))
106    }
107
108    /// Add more metadata to the provider (for transitive dependencies)
109    pub fn add_metadata(&mut self, name: String, metadata: PackageMetadata) {
110        let mut versions: Vec<Version> = metadata
111            .releases
112            .iter()
113            .filter(|(_, files)| files.iter().any(|f| !f.yanked && !files.is_empty()))
114            .filter_map(|(v, _)| Version::parse(v).ok())
115            .collect();
116
117        versions.sort_by(|a, b| b.cmp(a));
118        self.versions_cache.insert(name.clone(), versions);
119        self.metadata_cache.insert(name, metadata);
120    }
121
122    /// Get all packages that need metadata fetched
123    pub fn missing_packages(&self) -> Vec<String> {
124        let seen = self.seen_packages.borrow();
125        let mut missing: Vec<String> = seen
126            .iter()
127            .filter(|name| !self.metadata_cache.contains_key(*name))
128            .cloned()
129            .collect();
130
131        missing.sort();
132        missing
133    }
134
135    /// Pre-crawl the dependency graph to find all packages we need
136    /// This must be called before running pubgrub to ensure all metadata is available
137    pub fn discover_all_packages(&mut self) -> Vec<String> {
138        let mut to_process: Vec<String> = self.seen_packages.borrow().iter().cloned().collect();
139        let mut all_seen: HashSet<String> = to_process.iter().cloned().collect();
140
141        while let Some(pkg_name) = to_process.pop() {
142            let Some(metadata) = self.metadata_cache.get(&pkg_name) else {
143                continue;
144            };
145
146            // Get requires_dist from package info
147            let requires_dist = metadata.info.requires_dist.clone().unwrap_or_default();
148
149            for req_str in &requires_dist {
150                if let Ok(req) = Requirement::parse(req_str) {
151                    // Skip requirements with markers
152                    if req.marker.is_some() {
153                        continue;
154                    }
155                    // Skip extras
156                    if !req.extras.is_empty() {
157                        continue;
158                    }
159
160                    let normalized = Package::new(&req.name).name;
161                    if !all_seen.contains(&normalized) {
162                        all_seen.insert(normalized.clone());
163                        to_process.push(normalized);
164                    }
165                }
166            }
167        }
168
169        // Update seen_packages
170        *self.seen_packages.borrow_mut() = all_seen;
171
172        // Return missing packages
173        self.missing_packages()
174    }
175
176    /// Parse dependencies for a specific package version
177    fn parse_dependencies(
178        &self,
179        package: &str,
180        version: &Version,
181    ) -> Vec<(Package, Range<Version>)> {
182        let Some(metadata) = self.metadata_cache.get(package) else {
183            return vec![];
184        };
185
186        let version_str = version.to_string();
187        let Some(_files) = metadata.releases.get(&version_str) else {
188            return vec![];
189        };
190
191        // Get requires_dist from package info
192        let requires_dist = metadata.info.requires_dist.clone().unwrap_or_default();
193
194        let deps: Vec<_> = requires_dist
195            .iter()
196            .filter_map(|req_str| {
197                // Parse the requirement
198                let req = Requirement::parse(req_str).ok()?;
199
200                // Skip requirements with markers (MVP simplification)
201                // In the future, we'd evaluate markers against current environment
202                if req.marker.is_some() {
203                    return None;
204                }
205
206                // Skip extras (MVP simplification)
207                if !req.extras.is_empty() {
208                    return None;
209                }
210
211                let pkg = Package::new(&req.name);
212                let range = if let Some(ref spec) = req.specifier {
213                    VersionSpecifiers::parse(spec)
214                        .map(|s| s.to_pubgrub_range())
215                        .unwrap_or_else(|_| Range::any())
216                } else {
217                    Range::any()
218                };
219
220                Some((pkg, range))
221            })
222            .collect();
223
224        // Track all packages we've seen (interior mutability)
225        {
226            let mut seen = self.seen_packages.borrow_mut();
227            for (pkg, _) in &deps {
228                seen.insert(pkg.name.clone());
229            }
230        }
231
232        deps
233    }
234
235    /// Get versions for a package
236    fn get_versions(&self, package: &str) -> Option<&Vec<Version>> {
237        self.versions_cache.get(package)
238    }
239}
240
241impl DependencyProvider<Package, Version> for PyPIProvider {
242    fn choose_package_version<T: Borrow<Package>, U: Borrow<Range<Version>>>(
243        &self,
244        potential_packages: impl Iterator<Item = (T, U)>,
245    ) -> Result<(T, Option<Version>), Box<dyn StdError>> {
246        // Strategy: pick the package with fewest available versions (faster resolution)
247        let mut best: Option<(T, U, usize)> = None;
248
249        for (package, range) in potential_packages {
250            let pkg = package.borrow();
251
252            // Root package always has exactly one version
253            if pkg.name == ROOT_PACKAGE {
254                return Ok((package, Some(Version::new(vec![1, 0, 0]))));
255            }
256
257            // Count compatible versions
258            let count = match self.get_versions(&pkg.name) {
259                Some(versions) => versions
260                    .iter()
261                    .filter(|v| range.borrow().contains(v))
262                    .count(),
263                None => 0,
264            };
265
266            match &best {
267                None => best = Some((package, range, count)),
268                Some((_, _, best_count)) if count < *best_count => {
269                    best = Some((package, range, count));
270                }
271                _ => {}
272            }
273        }
274
275        let (package, range, _) = best.expect("at least one package");
276        let pkg = package.borrow();
277
278        // Find the highest compatible version
279        let version = self.get_versions(&pkg.name).and_then(|versions| {
280            versions
281                .iter()
282                .find(|v| range.borrow().contains(v))
283                .cloned()
284        });
285
286        Ok((package, version))
287    }
288
289    fn get_dependencies(
290        &self,
291        package: &Package,
292        version: &Version,
293    ) -> Result<Dependencies<Package, Version>, Box<dyn StdError>> {
294        // Root package returns the user's requirements
295        if package.name == ROOT_PACKAGE {
296            let deps: pubgrub::type_aliases::Map<Package, Range<Version>> =
297                self.root_deps.iter().cloned().collect();
298            return Ok(Dependencies::Known(deps));
299        }
300
301        // Get dependencies for the package version
302        let deps = self.parse_dependencies(&package.name, version);
303
304        let deps_map: pubgrub::type_aliases::Map<Package, Range<Version>> =
305            deps.into_iter().collect();
306
307        Ok(Dependencies::Known(deps_map))
308    }
309}
310
311// Implement required traits for Package to work with pubgrub
312impl Borrow<str> for Package {
313    fn borrow(&self) -> &str {
314        &self.name
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// The synthetic root package must be assigned a version even when the
    /// provider holds no metadata at all.
    #[test]
    fn test_root_package() {
        let provider = PyPIProvider::new(HashMap::new(), Vec::new());
        let root_pkg = Package::new(ROOT_PACKAGE);
        let any_version = Range::any();

        let candidates = std::iter::once((&root_pkg, &any_version));
        let (_, chosen) = provider.choose_package_version(candidates).unwrap();

        assert!(chosen.is_some());
    }
}