// pro_core/resolver/provider.rs
use std::borrow::Borrow;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::error::Error as StdError;

use pubgrub::range::Range;
use pubgrub::solver::{Dependencies, DependencyProvider};

use crate::index::{PackageMetadata, PyPIClient};
use crate::pep::{Requirement, Version, VersionSpecifiers};
use crate::resolver::Package;
use crate::Error;

/// Name of the synthetic package that stands for the project being resolved.
///
/// Its dependencies are the project's direct requirements, and the provider
/// always proposes it at version 1.0.0.
pub const ROOT_PACKAGE: &str = "root-project";

20pub struct PyPIProvider {
22 metadata_cache: HashMap<String, PackageMetadata>,
24 versions_cache: HashMap<String, Vec<Version>>,
26 seen_packages: RefCell<HashSet<String>>,
28 root_deps: Vec<(Package, Range<Version>)>,
30}
31
32impl PyPIProvider {
33 pub fn new(
35 metadata: HashMap<String, PackageMetadata>,
36 root_deps: Vec<(Package, Range<Version>)>,
37 ) -> Self {
38 let mut versions_cache = HashMap::new();
39
40 for (name, meta) in &metadata {
42 let mut versions: Vec<Version> = meta
43 .releases
44 .iter()
45 .filter(|(_, files)| files.iter().any(|f| !f.yanked && !files.is_empty()))
46 .filter_map(|(v, _)| Version::parse(v).ok())
47 .collect();
48
49 versions.sort_by(|a, b| b.cmp(a));
51 versions_cache.insert(name.clone(), versions);
52 }
53
54 let mut seen = HashSet::new();
56 for (pkg, _) in &root_deps {
57 seen.insert(pkg.name.clone());
58 }
59
60 Self {
61 metadata_cache: metadata,
62 versions_cache,
63 seen_packages: RefCell::new(seen),
64 root_deps,
65 }
66 }
67
68 pub async fn build(client: &PyPIClient, requirements: &[Requirement]) -> Result<Self, Error> {
70 let names: Vec<String> = requirements.iter().map(|r| r.name.clone()).collect();
72
73 let results = client.get_packages_concurrent(&names).await;
75
76 let mut metadata = HashMap::new();
78 for (name, result) in results {
79 match result {
80 Ok(meta) => {
81 metadata.insert(name, meta);
82 }
83 Err(e) => {
84 return Err(e);
85 }
86 }
87 }
88
89 let root_deps = requirements
91 .iter()
92 .map(|r| {
93 let pkg = Package::new(&r.name);
94 let range = if let Some(ref spec) = r.specifier {
95 VersionSpecifiers::parse(spec)
96 .map(|s| s.to_pubgrub_range())
97 .unwrap_or_else(|_| Range::any())
98 } else {
99 Range::any()
100 };
101 (pkg, range)
102 })
103 .collect();
104
105 Ok(Self::new(metadata, root_deps))
106 }
107
108 pub fn add_metadata(&mut self, name: String, metadata: PackageMetadata) {
110 let mut versions: Vec<Version> = metadata
111 .releases
112 .iter()
113 .filter(|(_, files)| files.iter().any(|f| !f.yanked && !files.is_empty()))
114 .filter_map(|(v, _)| Version::parse(v).ok())
115 .collect();
116
117 versions.sort_by(|a, b| b.cmp(a));
118 self.versions_cache.insert(name.clone(), versions);
119 self.metadata_cache.insert(name, metadata);
120 }
121
122 pub fn missing_packages(&self) -> Vec<String> {
124 let seen = self.seen_packages.borrow();
125 let mut missing: Vec<String> = seen
126 .iter()
127 .filter(|name| !self.metadata_cache.contains_key(*name))
128 .cloned()
129 .collect();
130
131 missing.sort();
132 missing
133 }
134
135 pub fn discover_all_packages(&mut self) -> Vec<String> {
138 let mut to_process: Vec<String> = self.seen_packages.borrow().iter().cloned().collect();
139 let mut all_seen: HashSet<String> = to_process.iter().cloned().collect();
140
141 while let Some(pkg_name) = to_process.pop() {
142 let Some(metadata) = self.metadata_cache.get(&pkg_name) else {
143 continue;
144 };
145
146 let requires_dist = metadata.info.requires_dist.clone().unwrap_or_default();
148
149 for req_str in &requires_dist {
150 if let Ok(req) = Requirement::parse(req_str) {
151 if req.marker.is_some() {
153 continue;
154 }
155 if !req.extras.is_empty() {
157 continue;
158 }
159
160 let normalized = Package::new(&req.name).name;
161 if !all_seen.contains(&normalized) {
162 all_seen.insert(normalized.clone());
163 to_process.push(normalized);
164 }
165 }
166 }
167 }
168
169 *self.seen_packages.borrow_mut() = all_seen;
171
172 self.missing_packages()
174 }
175
176 fn parse_dependencies(
178 &self,
179 package: &str,
180 version: &Version,
181 ) -> Vec<(Package, Range<Version>)> {
182 let Some(metadata) = self.metadata_cache.get(package) else {
183 return vec![];
184 };
185
186 let version_str = version.to_string();
187 let Some(_files) = metadata.releases.get(&version_str) else {
188 return vec![];
189 };
190
191 let requires_dist = metadata.info.requires_dist.clone().unwrap_or_default();
193
194 let deps: Vec<_> = requires_dist
195 .iter()
196 .filter_map(|req_str| {
197 let req = Requirement::parse(req_str).ok()?;
199
200 if req.marker.is_some() {
203 return None;
204 }
205
206 if !req.extras.is_empty() {
208 return None;
209 }
210
211 let pkg = Package::new(&req.name);
212 let range = if let Some(ref spec) = req.specifier {
213 VersionSpecifiers::parse(spec)
214 .map(|s| s.to_pubgrub_range())
215 .unwrap_or_else(|_| Range::any())
216 } else {
217 Range::any()
218 };
219
220 Some((pkg, range))
221 })
222 .collect();
223
224 {
226 let mut seen = self.seen_packages.borrow_mut();
227 for (pkg, _) in &deps {
228 seen.insert(pkg.name.clone());
229 }
230 }
231
232 deps
233 }
234
235 fn get_versions(&self, package: &str) -> Option<&Vec<Version>> {
237 self.versions_cache.get(package)
238 }
239}
240
241impl DependencyProvider<Package, Version> for PyPIProvider {
242 fn choose_package_version<T: Borrow<Package>, U: Borrow<Range<Version>>>(
243 &self,
244 potential_packages: impl Iterator<Item = (T, U)>,
245 ) -> Result<(T, Option<Version>), Box<dyn StdError>> {
246 let mut best: Option<(T, U, usize)> = None;
248
249 for (package, range) in potential_packages {
250 let pkg = package.borrow();
251
252 if pkg.name == ROOT_PACKAGE {
254 return Ok((package, Some(Version::new(vec![1, 0, 0]))));
255 }
256
257 let count = match self.get_versions(&pkg.name) {
259 Some(versions) => versions
260 .iter()
261 .filter(|v| range.borrow().contains(v))
262 .count(),
263 None => 0,
264 };
265
266 match &best {
267 None => best = Some((package, range, count)),
268 Some((_, _, best_count)) if count < *best_count => {
269 best = Some((package, range, count));
270 }
271 _ => {}
272 }
273 }
274
275 let (package, range, _) = best.expect("at least one package");
276 let pkg = package.borrow();
277
278 let version = self.get_versions(&pkg.name).and_then(|versions| {
280 versions
281 .iter()
282 .find(|v| range.borrow().contains(v))
283 .cloned()
284 });
285
286 Ok((package, version))
287 }
288
289 fn get_dependencies(
290 &self,
291 package: &Package,
292 version: &Version,
293 ) -> Result<Dependencies<Package, Version>, Box<dyn StdError>> {
294 if package.name == ROOT_PACKAGE {
296 let deps: pubgrub::type_aliases::Map<Package, Range<Version>> =
297 self.root_deps.iter().cloned().collect();
298 return Ok(Dependencies::Known(deps));
299 }
300
301 let deps = self.parse_dependencies(&package.name, version);
303
304 let deps_map: pubgrub::type_aliases::Map<Package, Range<Version>> =
305 deps.into_iter().collect();
306
307 Ok(Dependencies::Known(deps_map))
308 }
309}
310
311impl Borrow<str> for Package {
313 fn borrow(&self) -> &str {
314 &self.name
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// The root package must get a version even when no metadata is loaded.
    #[test]
    fn test_root_package() {
        let provider = PyPIProvider::new(HashMap::new(), Vec::new());
        let root = Package::new(ROOT_PACKAGE);
        let any_version = Range::any();

        let candidates = std::iter::once((&root, &any_version));
        let (_, chosen) = provider.choose_package_version(candidates).unwrap();

        assert!(chosen.is_some());
    }
}