1use crate::{DependencyGroupSpecifier, DependencyGroups, ResolvedDependencies};
2use indexmap::IndexMap;
3use pep508_rs::Requirement;
4use std::fmt::Display;
5use thiserror::Error;
6
7#[derive(Debug, Error)]
8#[error(transparent)]
9pub struct ResolveError(#[from] ResolveErrorKind);
10
11#[derive(Debug, Error)]
12pub enum ResolveErrorKind {
13 #[error("Failed to find optional dependency `{name}` included by {included_by}")]
14 OptionalDependencyNotFound { name: String, included_by: Item },
15 #[error("Failed to find dependency group `{name}` included by {included_by}")]
16 DependencyGroupNotFound { name: String, included_by: Item },
17 #[error("Cycles are not supported: {0}")]
18 DependencyGroupCycle(Cycle),
19}
20
/// A chain of items that (transitively) includes itself, in inclusion order.
///
/// When displayed, the first item is repeated at the end to close the loop,
/// so `[group:a, group:b]` renders as `group:a -> group:b -> group:a`.
#[derive(Debug)]
pub struct Cycle(Vec<Item>);

impl Display for Cycle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // An empty cycle renders as nothing at all.
        let Some(start) = self.0.first() else {
            return Ok(());
        };
        for (position, item) in self.0.iter().enumerate() {
            if position > 0 {
                f.write_str(" -> ")?;
            }
            write!(f, "{item}")?;
        }
        // Point back at where the cycle started.
        write!(f, " -> {start}")
    }
}

/// A named entry that takes part in dependency resolution.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Item {
    /// An extra from `[project.optional-dependencies]`, rendered as `extra:<name>`.
    Extra(String),
    /// A group from `[dependency-groups]`, rendered as `group:<name>`.
    Group(String),
}

impl Display for Item {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, name) = match self {
            Item::Extra(name) => ("extra", name),
            Item::Group(name) => ("group", name),
        };
        write!(f, "{prefix}:{name}")
    }
}
57
58pub(crate) fn resolve(
59 self_reference_name: Option<&str>,
60 optional_dependencies: Option<&IndexMap<String, Vec<Requirement>>>,
61 dependency_groups: Option<&DependencyGroups>,
62) -> Result<ResolvedDependencies, ResolveError> {
63 let mut resolved_dependencies = ResolvedDependencies::default();
64
65 if let Some(optional_dependencies) = optional_dependencies {
67 for extra in optional_dependencies.keys() {
68 resolve_optional_dependency(
69 extra,
70 optional_dependencies,
71 &mut resolved_dependencies,
72 &mut Vec::new(),
73 self_reference_name,
74 )?;
75 }
76 }
77
78 if let Some(dependency_groups) = dependency_groups {
80 for group in dependency_groups.keys() {
81 resolve_dependency_group(
83 group,
84 optional_dependencies.unwrap_or(&IndexMap::default()),
85 dependency_groups,
86 &mut resolved_dependencies,
87 &mut Vec::new(),
88 self_reference_name,
89 )?;
90 }
91 }
92
93 Ok(resolved_dependencies)
94}
95
96fn resolve_optional_dependency(
98 extra: &str,
99 optional_dependencies: &IndexMap<String, Vec<Requirement>>,
100 resolved: &mut ResolvedDependencies,
101 parents: &mut Vec<Item>,
102 project_name: Option<&str>,
103) -> Result<Vec<Requirement>, ResolveError> {
104 if let Some(requirements) = resolved.optional_dependencies.get(extra) {
105 return Ok(requirements.clone());
106 }
107
108 let Some(unresolved_requirements) = optional_dependencies.get(extra) else {
109 let parent = parents
110 .iter()
111 .last()
112 .expect("missing optional dependency must have parent")
113 .clone();
114 return Err(ResolveErrorKind::OptionalDependencyNotFound {
115 name: extra.to_string(),
116 included_by: parent,
117 }
118 .into());
119 };
120
121 let item = Item::Extra(extra.to_string());
123 if parents.contains(&item) {
124 return Err(ResolveErrorKind::DependencyGroupCycle(Cycle(parents.clone())).into());
125 }
126 parents.push(item);
127
128 let mut resolved_requirements = Vec::with_capacity(unresolved_requirements.len());
130 for unresolved_requirement in unresolved_requirements.iter() {
131 if project_name
133 .is_some_and(|project_name| project_name == unresolved_requirement.name.to_string())
134 {
135 for extra in &unresolved_requirement.extras {
138 let extra_string = extra.to_string();
139 resolved_requirements.extend(resolve_optional_dependency(
140 &extra_string,
141 optional_dependencies,
142 resolved,
143 parents,
144 project_name,
145 )?);
146 }
147 } else {
148 resolved_requirements.push(unresolved_requirement.clone())
149 }
150 }
151 resolved
152 .optional_dependencies
153 .insert(extra.to_string(), resolved_requirements.clone());
154 parents.pop();
155 Ok(resolved_requirements)
156}
157
158fn resolve_dependency_group(
160 dep_group: &String,
161 optional_dependencies: &IndexMap<String, Vec<Requirement>>,
162 dependency_groups: &DependencyGroups,
163 resolved: &mut ResolvedDependencies,
164 parents: &mut Vec<Item>,
165 project_name: Option<&str>,
166) -> Result<Vec<Requirement>, ResolveError> {
167 if let Some(requirements) = resolved.dependency_groups.get(dep_group) {
168 return Ok(requirements.clone());
169 }
170
171 let Some(unresolved_requirements) = dependency_groups.get(dep_group) else {
172 let parent = parents
173 .iter()
174 .last()
175 .expect("missing optional dependency must have parent")
176 .clone();
177 return Err(ResolveErrorKind::DependencyGroupNotFound {
178 name: dep_group.to_string(),
179 included_by: parent,
180 }
181 .into());
182 };
183
184 let item = Item::Group(dep_group.to_string());
186 if parents.contains(&item) {
187 return Err(ResolveErrorKind::DependencyGroupCycle(Cycle(parents.clone())).into());
188 }
189 parents.push(item);
190
191 let mut resolved_requirements = Vec::with_capacity(unresolved_requirements.len());
193 for unresolved_requirement in unresolved_requirements.iter() {
194 match unresolved_requirement {
195 DependencyGroupSpecifier::String(spec) => {
196 if project_name.is_some_and(|project_name| project_name == spec.name.to_string()) {
197 for extra in &spec.extras {
198 resolved_requirements.extend(resolve_optional_dependency(
199 extra.as_ref(),
200 optional_dependencies,
201 resolved,
202 parents,
203 project_name,
204 )?);
205 }
206 } else {
207 resolved_requirements.push(spec.clone())
208 }
209 }
210 DependencyGroupSpecifier::Table { include_group } => {
211 resolved_requirements.extend(resolve_dependency_group(
212 include_group,
213 optional_dependencies,
214 dependency_groups,
215 resolved,
216 parents,
217 project_name,
218 )?);
219 }
220 }
221 }
222 resolved
224 .dependency_groups
225 .insert(dep_group.to_string(), resolved_requirements.clone());
226 parents.pop();
227 Ok(resolved_requirements)
228}
229
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use pep508_rs::Requirement;

    use crate::resolution::{resolve_optional_dependency, Item};
    use crate::{PyProjectToml, ResolvedDependencies};

    /// Parses `source` as a `pyproject.toml` and resolves it, panicking on error.
    fn resolve_ok(source: &str) -> ResolvedDependencies {
        PyProjectToml::new(source).unwrap().resolve().unwrap()
    }

    /// Parses `source`, expects resolution to fail, and returns the error message.
    fn resolve_err(source: &str) -> String {
        PyProjectToml::new(source)
            .unwrap()
            .resolve()
            .unwrap_err()
            .to_string()
    }

    /// Parses each PEP 508 string into a `Requirement`.
    fn requirements(specs: &[&str]) -> Vec<Requirement> {
        specs
            .iter()
            .map(|spec| Requirement::from_str(spec).unwrap())
            .collect()
    }

    #[test]
    fn parse_pyproject_toml_optional_dependencies_resolve() {
        let resolved = resolve_ok(
            r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["beta", "gamma", "delta"]
            epsilon = ["eta<2.0", "theta==2024.09.01"]
            iota = ["spam[alpha]"]
            "#,
        );
        assert_eq!(
            resolved.optional_dependencies["iota"],
            requirements(&["beta", "gamma", "delta"])
        );
    }

    #[test]
    fn parse_pyproject_toml_optional_dependencies_cycle() {
        let message = resolve_err(
            r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["spam[iota]"]
            iota = ["spam[alpha]"]
            "#,
        );
        assert_eq!(
            message,
            "Cycles are not supported: extra:alpha -> extra:iota -> extra:alpha"
        );
    }

    #[test]
    fn parse_pyproject_toml_optional_dependencies_missing_include() {
        let message = resolve_err(
            r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            iota = ["spam[alpha]"]
            "#,
        );
        assert_eq!(
            message,
            "Failed to find optional dependency `alpha` included by extra:iota"
        );
    }

    #[test]
    fn parse_pyproject_toml_optional_dependencies_missing_top_level() {
        let pyproject_toml = PyProjectToml::new(
            r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["beta"]
            "#,
        )
        .unwrap();
        let optional_dependencies = pyproject_toml
            .project
            .as_ref()
            .and_then(|project| project.optional_dependencies.as_ref())
            .unwrap();
        // Call the resolver directly so the missing extra has an explicit includer.
        let err = resolve_optional_dependency(
            "foo",
            optional_dependencies,
            &mut ResolvedDependencies::default(),
            &mut vec![Item::Extra("bar".to_string())],
            Some("spam"),
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Failed to find optional dependency `foo` included by extra:bar"
        );
    }

    #[test]
    fn parse_pyproject_toml_dependency_groups_resolve() {
        let resolved = resolve_ok(
            r#"
            [dependency-groups]
            alpha = ["beta", "gamma", "delta"]
            epsilon = ["eta<2.0", "theta==2024.09.01"]
            iota = [{include-group = "alpha"}]
            "#,
        );
        assert_eq!(
            resolved.dependency_groups["iota"],
            requirements(&["beta", "gamma", "delta"])
        );
    }

    #[test]
    fn parse_pyproject_toml_dependency_groups_cycle() {
        let message = resolve_err(
            r#"
            [dependency-groups]
            alpha = [{include-group = "iota"}]
            iota = [{include-group = "alpha"}]
            "#,
        );
        assert_eq!(
            message,
            "Cycles are not supported: group:alpha -> group:iota -> group:alpha"
        );
    }

    #[test]
    fn parse_pyproject_toml_dependency_groups_missing_include() {
        let message = resolve_err(
            r#"
            [dependency-groups]
            iota = [{include-group = "alpha"}]
            "#,
        );
        assert_eq!(
            message,
            "Failed to find dependency group `alpha` included by group:iota"
        );
    }

    #[test]
    fn parse_pyproject_toml_dependency_groups_with_optional_dependencies() {
        let resolved = resolve_ok(
            r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]

            [dependency-groups]
            dev = ["spam[test]"]
            "#,
        );
        assert_eq!(resolved.dependency_groups["dev"], requirements(&["pytest"]));
    }

    #[test]
    fn name_collision() {
        let resolved = resolve_ok(
            r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            dev = ["pytest"]

            [dependency-groups]
            dev = ["ruff"]
            "#,
        );
        assert_eq!(resolved.optional_dependencies["dev"], requirements(&["pytest"]));
        assert_eq!(resolved.dependency_groups["dev"], requirements(&["ruff"]));
    }

    #[test]
    fn optional_dependencies_are_not_dependency_groups() {
        let resolved = resolve_ok(
            r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]

            [dependency-groups]
            dev = ["spam[test]"]
            "#,
        );
        assert!(resolved.optional_dependencies.contains_key("test"));
        assert!(!resolved.dependency_groups.contains_key("test"));
        assert!(resolved.dependency_groups.contains_key("dev"));
    }

    #[test]
    fn mixed_resolution() {
        let resolved = resolve_ok(
            r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]
            numpy = ["numpy"]

            [dependency-groups]
            dev = ["spam[test]"]
            test = ["spam[numpy]"]
            "#,
        );
        assert_eq!(resolved.dependency_groups["dev"], requirements(&["pytest"]));
        assert_eq!(resolved.dependency_groups["test"], requirements(&["numpy"]));
    }
}