// pyproject_toml/resolution.rs

use crate::{DependencyGroupSpecifier, DependencyGroups, ResolvedDependencies};
use indexmap::IndexMap;
use pep508_rs::Requirement;
use std::fmt::Display;
use thiserror::Error;
6
7#[derive(Debug, Error)]
8#[error(transparent)]
9pub struct ResolveError(#[from] ResolveErrorKind);
10
11#[derive(Debug, Error)]
12pub enum ResolveErrorKind {
13    #[error("Failed to find optional dependency `{name}` included by {included_by}")]
14    OptionalDependencyNotFound { name: String, included_by: Item },
15    #[error("Failed to find dependency group `{name}` included by {included_by}")]
16    DependencyGroupNotFound { name: String, included_by: Item },
17    #[error("Cycles are not supported: {0}")]
18    DependencyGroupCycle(Cycle),
19}
20
21/// A cycle in the recursion.
22#[derive(Debug)]
23pub struct Cycle(Vec<Item>);
24
25/// Display a cycle, e.g., `a -> b -> c -> a`.
26impl Display for Cycle {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        let Some((first, rest)) = self.0.split_first() else {
29            return Ok(());
30        };
31        write!(f, "{first}")?;
32        for group in rest {
33            write!(f, " -> {group}")?;
34        }
35        write!(f, " -> {first}")?;
36        Ok(())
37    }
38}
39
/// A reference to either an optional dependency or a dependency group.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Item {
    /// An optional dependency (extra), identified by its name.
    Extra(String),
    /// A dependency group, identified by its name.
    Group(String),
}

/// Display an item with its kind as prefix, i.e. `extra:<name>` or `group:<name>`.
impl Display for Item {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Item::Extra(extra) => write!(f, "extra:{extra}"),
            Item::Group(group) => write!(f, "group:{group}"),
        }
    }
}
57
58pub(crate) fn resolve(
59    self_reference_name: Option<&str>,
60    optional_dependencies: Option<&IndexMap<String, Vec<Requirement>>>,
61    dependency_groups: Option<&DependencyGroups>,
62) -> Result<ResolvedDependencies, ResolveError> {
63    let mut resolved_dependencies = ResolvedDependencies::default();
64
65    // Resolve optional dependencies, which may only reference optional dependencies.
66    if let Some(optional_dependencies) = optional_dependencies {
67        for extra in optional_dependencies.keys() {
68            resolve_optional_dependency(
69                extra,
70                optional_dependencies,
71                &mut resolved_dependencies,
72                &mut Vec::new(),
73                self_reference_name,
74            )?;
75        }
76    }
77
78    // Resolve dependency groups, which may reference dependency groups and optional dependencies.
79    if let Some(dependency_groups) = dependency_groups {
80        for group in dependency_groups.keys() {
81            // It's a reference to other groups. Recurse into them
82            resolve_dependency_group(
83                group,
84                optional_dependencies.unwrap_or(&IndexMap::default()),
85                dependency_groups,
86                &mut resolved_dependencies,
87                &mut Vec::new(),
88                self_reference_name,
89            )?;
90        }
91    }
92
93    Ok(resolved_dependencies)
94}
95
96/// Resolves a single optional dependency.
97fn resolve_optional_dependency(
98    extra: &str,
99    optional_dependencies: &IndexMap<String, Vec<Requirement>>,
100    resolved: &mut ResolvedDependencies,
101    parents: &mut Vec<Item>,
102    project_name: Option<&str>,
103) -> Result<Vec<Requirement>, ResolveError> {
104    if let Some(requirements) = resolved.optional_dependencies.get(extra) {
105        return Ok(requirements.clone());
106    }
107
108    let Some(unresolved_requirements) = optional_dependencies.get(extra) else {
109        let parent = parents
110            .iter()
111            .last()
112            .expect("missing optional dependency must have parent")
113            .clone();
114        return Err(ResolveErrorKind::OptionalDependencyNotFound {
115            name: extra.to_string(),
116            included_by: parent,
117        }
118        .into());
119    };
120
121    // Check for cycles
122    let item = Item::Extra(extra.to_string());
123    if parents.contains(&item) {
124        return Err(ResolveErrorKind::DependencyGroupCycle(Cycle(parents.clone())).into());
125    }
126    parents.push(item);
127
128    // Recurse into references, and add their resolved requirements to our own requirements.
129    let mut resolved_requirements = Vec::with_capacity(unresolved_requirements.len());
130    for unresolved_requirement in unresolved_requirements.iter() {
131        // TODO: This should become a `PackageName` in the next breaking release.
132        if project_name
133            .is_some_and(|project_name| project_name == unresolved_requirement.name.to_string())
134        {
135            // Resolve each extra individually, as each refers to a different optional
136            // dependency entry.
137            for extra in &unresolved_requirement.extras {
138                let extra_string = extra.to_string();
139                resolved_requirements.extend(resolve_optional_dependency(
140                    &extra_string,
141                    optional_dependencies,
142                    resolved,
143                    parents,
144                    project_name,
145                )?);
146            }
147        } else {
148            resolved_requirements.push(unresolved_requirement.clone())
149        }
150    }
151    resolved
152        .optional_dependencies
153        .insert(extra.to_string(), resolved_requirements.clone());
154    parents.pop();
155    Ok(resolved_requirements)
156}
157
158/// Resolves a single dependency group.
159fn resolve_dependency_group(
160    dep_group: &String,
161    optional_dependencies: &IndexMap<String, Vec<Requirement>>,
162    dependency_groups: &DependencyGroups,
163    resolved: &mut ResolvedDependencies,
164    parents: &mut Vec<Item>,
165    project_name: Option<&str>,
166) -> Result<Vec<Requirement>, ResolveError> {
167    if let Some(requirements) = resolved.dependency_groups.get(dep_group) {
168        return Ok(requirements.clone());
169    }
170
171    let Some(unresolved_requirements) = dependency_groups.get(dep_group) else {
172        let parent = parents
173            .iter()
174            .last()
175            .expect("missing optional dependency must have parent")
176            .clone();
177        return Err(ResolveErrorKind::DependencyGroupNotFound {
178            name: dep_group.to_string(),
179            included_by: parent,
180        }
181        .into());
182    };
183
184    // Check for cycles
185    let item = Item::Group(dep_group.to_string());
186    if parents.contains(&item) {
187        return Err(ResolveErrorKind::DependencyGroupCycle(Cycle(parents.clone())).into());
188    }
189    parents.push(item);
190
191    // Otherwise, perform recursion, as required, on the dependency group's specifiers
192    let mut resolved_requirements = Vec::with_capacity(unresolved_requirements.len());
193    for unresolved_requirement in unresolved_requirements.iter() {
194        match unresolved_requirement {
195            DependencyGroupSpecifier::String(spec) => {
196                if project_name.is_some_and(|project_name| project_name == spec.name.to_string()) {
197                    for extra in &spec.extras {
198                        resolved_requirements.extend(resolve_optional_dependency(
199                            extra.as_ref(),
200                            optional_dependencies,
201                            resolved,
202                            parents,
203                            project_name,
204                        )?);
205                    }
206                } else {
207                    resolved_requirements.push(spec.clone())
208                }
209            }
210            DependencyGroupSpecifier::Table { include_group } => {
211                resolved_requirements.extend(resolve_dependency_group(
212                    include_group,
213                    optional_dependencies,
214                    dependency_groups,
215                    resolved,
216                    parents,
217                    project_name,
218                )?);
219            }
220        }
221    }
222    // Add the resolved group to IndexMap
223    resolved
224        .dependency_groups
225        .insert(dep_group.to_string(), resolved_requirements.clone());
226    parents.pop();
227    Ok(resolved_requirements)
228}
229
#[cfg(test)]
mod tests {
    use pep508_rs::Requirement;
    use std::str::FromStr;

    use crate::resolution::{resolve_optional_dependency, Item};
    use crate::{PyProjectToml, ResolvedDependencies};

    /// A self-referencing extra is replaced by the requirements of the referenced extra.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_resolve() {
        let source = r#"[project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["beta", "gamma", "delta"]
            epsilon = ["eta<2.0", "theta==2024.09.01"]
            iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();

        assert_eq!(
            resolved_dependencies.optional_dependencies["iota"],
            vec![
                Requirement::from_str("beta").unwrap(),
                Requirement::from_str("gamma").unwrap(),
                Requirement::from_str("delta").unwrap()
            ]
        );
    }

    /// Two extras that include each other are rejected as a cycle.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_cycle() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["spam[iota]"]
            iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Cycles are not supported: extra:alpha -> extra:iota -> extra:alpha"
        )
    }

    /// Referencing an undefined extra names both the missing extra and the including one.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_missing_include() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Failed to find optional dependency `alpha` included by extra:iota"
        )
    }

    /// Calling the resolver directly with an undefined extra reports the given parent.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_missing_top_level() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["beta"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let mut resolved = ResolvedDependencies::default();
        let err = resolve_optional_dependency(
            "foo",
            pyproject_toml
                .project
                .as_ref()
                .unwrap()
                .optional_dependencies
                .as_ref()
                .unwrap(),
            &mut resolved,
            &mut vec![Item::Extra("bar".to_string())],
            Some("spam"),
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Failed to find optional dependency `foo` included by extra:bar"
        );
    }

    /// An `include-group` entry is replaced by the requirements of the included group.
    #[test]
    fn parse_pyproject_toml_dependency_groups_resolve() {
        let source = r#"
            [dependency-groups]
            alpha = ["beta", "gamma", "delta"]
            epsilon = ["eta<2.0", "theta==2024.09.01"]
            iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();

        assert_eq!(
            resolved_dependencies.dependency_groups["iota"],
            vec![
                Requirement::from_str("beta").unwrap(),
                Requirement::from_str("gamma").unwrap(),
                Requirement::from_str("delta").unwrap()
            ]
        );
    }

    /// Two groups that include each other are rejected as a cycle.
    #[test]
    fn parse_pyproject_toml_dependency_groups_cycle() {
        let source = r#"
            [dependency-groups]
            alpha = [{include-group = "iota"}]
            iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Cycles are not supported: group:alpha -> group:iota -> group:alpha"
        )
    }

    /// Including an undefined group names both the missing group and the including one.
    #[test]
    fn parse_pyproject_toml_dependency_groups_missing_include() {
        let source = r#"
            [dependency-groups]
            iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Failed to find dependency group `alpha` included by group:iota"
        )
    }

    /// A group may pull in an extra through a self-referencing requirement.
    #[test]
    fn parse_pyproject_toml_dependency_groups_with_optional_dependencies() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]

            [dependency-groups]
            dev = ["spam[test]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
    }

    /// An extra and a group with the same name are resolved independently.
    #[test]
    fn name_collision() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            dev = ["pytest"]

            [dependency-groups]
            dev = ["ruff"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.optional_dependencies["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("ruff").unwrap()]
        );
    }

    /// Resolving an extra on behalf of a group must not register it as a group.
    #[test]
    fn optional_dependencies_are_not_dependency_groups() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]

            [dependency-groups]
            dev = ["spam[test]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert!(resolved_dependencies
            .optional_dependencies
            .contains_key("test"));
        assert!(!resolved_dependencies.dependency_groups.contains_key("test"));
        assert!(resolved_dependencies.dependency_groups.contains_key("dev"));
    }

    /// A group named like an extra still resolves `spam[test]` to the extra, not to itself.
    #[test]
    fn mixed_resolution() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]
            numpy = ["numpy"]

            [dependency-groups]
            dev = ["spam[test]"]
            test = ["spam[numpy]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
        assert_eq!(
            resolved_dependencies.dependency_groups["test"],
            vec![Requirement::from_str("numpy").unwrap()]
        );
    }
}