pyproject_toml/
resolution.rs

1use crate::{DependencyGroupSpecifier, DependencyGroups, ResolvedDependencies};
2use indexmap::IndexMap;
3use pep508_rs::{ExtraName, Requirement};
4use std::fmt::Display;
5use std::str::FromStr;
6use thiserror::Error;
7
8/// Normalize a group/extra name according to PEP 685.
9fn normalize_name(name: &str) -> String {
10    ExtraName::from_str(name)
11        .map(|extra| extra.to_string())
12        .unwrap_or_else(|_| name.to_string())
13}
14
/// Error returned when resolving optional dependencies or dependency groups fails.
///
/// This is an opaque wrapper around [`ResolveErrorKind`]. Its `Display` output is the
/// wrapped kind's message (`#[error(transparent)]`), and it can be built from a kind via
/// `From`/`.into()`.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ResolveError(#[from] ResolveErrorKind);

/// The specific reason a resolution failed.
#[derive(Debug, Error)]
pub enum ResolveErrorKind {
    /// A self-reference such as `<project>[name]` named an optional dependency that is not
    /// declared. `name` is the name as it was referenced. `included_by` is the extra or
    /// group whose requirements contained the reference.
    #[error("Failed to find optional dependency `{name}` included by {included_by}")]
    OptionalDependencyNotFound { name: String, included_by: Item },
    /// An `include-group` entry named a dependency group that is not declared. `name` is
    /// the name as it was referenced. `included_by` is the group that included it.
    #[error("Failed to find dependency group `{name}` included by {included_by}")]
    DependencyGroupNotFound { name: String, included_by: Item },
    /// References loop back on themselves. Despite the variant name, this is also reported
    /// for cycles between optional dependencies.
    #[error("Cycles are not supported: {0}")]
    DependencyGroupCycle(Cycle),
}
28
/// A cycle in the recursion, stored as the sequence of items visited.
///
/// When displayed, the first item is repeated at the end to close the loop.
#[derive(Debug)]
pub struct Cycle(Vec<Item>);

/// Display a cycle, e.g., `a -> b -> c -> a`.
///
/// Each item is printed in order, then the first item again. An empty cycle prints nothing.
impl Display for Cycle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0.as_slice() {
            [] => Ok(()),
            [first, rest @ ..] => {
                write!(f, "{first}")?;
                rest.iter().try_for_each(|group| write!(f, " -> {group}"))?;
                write!(f, " -> {first}")
            }
        }
    }
}

/// A reference to either an optional dependency or a dependency group.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Item {
    /// An optional dependency (extra), by name.
    Extra(String),
    /// A dependency group, by name.
    Group(String),
}

/// Displays as `extra:<name>` or `group:<name>`.
impl Display for Item {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Extra(extra) => write!(f, "extra:{extra}"),
            Self::Group(group) => write!(f, "group:{group}"),
        }
    }
}
65
66pub(crate) fn resolve(
67    self_reference_name: Option<&str>,
68    optional_dependencies: Option<&IndexMap<String, Vec<Requirement>>>,
69    dependency_groups: Option<&DependencyGroups>,
70) -> Result<ResolvedDependencies, ResolveError> {
71    let mut resolved_dependencies = ResolvedDependencies::default();
72
73    // Resolve optional dependencies, which may only reference optional dependencies.
74    if let Some(optional_dependencies) = optional_dependencies {
75        for extra in optional_dependencies.keys() {
76            resolve_optional_dependency(
77                extra,
78                optional_dependencies,
79                &mut resolved_dependencies,
80                &mut Vec::new(),
81                self_reference_name,
82            )?;
83        }
84    }
85
86    // Resolve dependency groups, which may reference dependency groups and optional dependencies.
87    if let Some(dependency_groups) = dependency_groups {
88        for group in dependency_groups.keys() {
89            // It's a reference to other groups. Recurse into them
90            resolve_dependency_group(
91                group,
92                optional_dependencies.unwrap_or(&IndexMap::default()),
93                dependency_groups,
94                &mut resolved_dependencies,
95                &mut Vec::new(),
96                self_reference_name,
97            )?;
98        }
99    }
100
101    Ok(resolved_dependencies)
102}
103
104/// Resolves a single optional dependency.
105fn resolve_optional_dependency(
106    extra: &str,
107    optional_dependencies: &IndexMap<String, Vec<Requirement>>,
108    resolved: &mut ResolvedDependencies,
109    parents: &mut Vec<Item>,
110    project_name: Option<&str>,
111) -> Result<Vec<Requirement>, ResolveError> {
112    if let Some(requirements) = resolved.optional_dependencies.get(extra) {
113        return Ok(requirements.clone());
114    }
115
116    let normalized_extra = normalize_name(extra);
117
118    // Find the key in optional_dependencies by comparing normalized versions
119    // TODO: next breaking release remove this once Extra is added
120    let unresolved_requirements = optional_dependencies
121        .iter()
122        .find(|(key, _)| normalize_name(key) == normalized_extra)
123        .map(|(_, reqs)| reqs);
124
125    let Some(unresolved_requirements) = unresolved_requirements else {
126        let parent = parents
127            .iter()
128            .last()
129            .expect("missing optional dependency must have parent")
130            .clone();
131        return Err(ResolveErrorKind::OptionalDependencyNotFound {
132            name: extra.to_string(),
133            included_by: parent,
134        }
135        .into());
136    };
137
138    // Check for cycles
139    let item = Item::Extra(extra.to_string());
140    if parents.contains(&item) {
141        return Err(ResolveErrorKind::DependencyGroupCycle(Cycle(parents.clone())).into());
142    }
143    parents.push(item);
144
145    // Recurse into references, and add their resolved requirements to our own requirements.
146    let mut resolved_requirements = Vec::with_capacity(unresolved_requirements.len());
147    for unresolved_requirement in unresolved_requirements.iter() {
148        // TODO: This should become a `PackageName` in the next breaking release.
149        if project_name
150            .is_some_and(|project_name| project_name == unresolved_requirement.name.to_string())
151        {
152            // Resolve each extra individually, as each refers to a different optional
153            // dependency entry.
154            for extra in &unresolved_requirement.extras {
155                let extra_string = extra.to_string();
156                resolved_requirements.extend(resolve_optional_dependency(
157                    &extra_string,
158                    optional_dependencies,
159                    resolved,
160                    parents,
161                    project_name,
162                )?);
163            }
164        } else {
165            resolved_requirements.push(unresolved_requirement.clone())
166        }
167    }
168    resolved
169        .optional_dependencies
170        .insert(extra.to_string(), resolved_requirements.clone());
171    parents.pop();
172    Ok(resolved_requirements)
173}
174
175/// Resolves a single dependency group.
176fn resolve_dependency_group(
177    dep_group: &String,
178    optional_dependencies: &IndexMap<String, Vec<Requirement>>,
179    dependency_groups: &DependencyGroups,
180    resolved: &mut ResolvedDependencies,
181    parents: &mut Vec<Item>,
182    project_name: Option<&str>,
183) -> Result<Vec<Requirement>, ResolveError> {
184    if let Some(requirements) = resolved.dependency_groups.get(dep_group) {
185        return Ok(requirements.clone());
186    }
187
188    let Some(unresolved_requirements) = dependency_groups.get(dep_group) else {
189        let parent = parents
190            .iter()
191            .last()
192            .expect("missing optional dependency must have parent")
193            .clone();
194        return Err(ResolveErrorKind::DependencyGroupNotFound {
195            name: dep_group.to_string(),
196            included_by: parent,
197        }
198        .into());
199    };
200
201    // Check for cycles
202    let item = Item::Group(dep_group.to_string());
203    if parents.contains(&item) {
204        return Err(ResolveErrorKind::DependencyGroupCycle(Cycle(parents.clone())).into());
205    }
206    parents.push(item);
207
208    // Otherwise, perform recursion, as required, on the dependency group's specifiers
209    let mut resolved_requirements = Vec::with_capacity(unresolved_requirements.len());
210    for unresolved_requirement in unresolved_requirements.iter() {
211        match unresolved_requirement {
212            DependencyGroupSpecifier::String(spec) => {
213                if project_name.is_some_and(|project_name| project_name == spec.name.to_string()) {
214                    for extra in &spec.extras {
215                        resolved_requirements.extend(resolve_optional_dependency(
216                            extra.as_ref(),
217                            optional_dependencies,
218                            resolved,
219                            parents,
220                            project_name,
221                        )?);
222                    }
223                } else {
224                    resolved_requirements.push(spec.clone())
225                }
226            }
227            DependencyGroupSpecifier::Table { include_group } => {
228                resolved_requirements.extend(resolve_dependency_group(
229                    include_group,
230                    optional_dependencies,
231                    dependency_groups,
232                    resolved,
233                    parents,
234                    project_name,
235                )?);
236            }
237        }
238    }
239    // Add the resolved group to IndexMap
240    resolved
241        .dependency_groups
242        .insert(dep_group.to_string(), resolved_requirements.clone());
243    parents.pop();
244    Ok(resolved_requirements)
245}
246
// Unit tests for optional-dependency and dependency-group resolution, driven through
// `PyProjectToml::resolve` and, for edge cases, `resolve_optional_dependency` directly.
#[cfg(test)]
mod tests {
    use pep508_rs::Requirement;
    use std::str::FromStr;

    use crate::resolution::{resolve_optional_dependency, Item};
    use crate::{PyProjectToml, ResolvedDependencies};

    // A self-reference `spam[alpha]` is replaced by the requirements of extra `alpha`.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_resolve() {
        let source = r#"[project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["beta", "gamma", "delta"]
            epsilon = ["eta<2.0", "theta==2024.09.01"]
            iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();

        assert_eq!(
            resolved_dependencies.optional_dependencies["iota"],
            vec![
                Requirement::from_str("beta").unwrap(),
                Requirement::from_str("gamma").unwrap(),
                Requirement::from_str("delta").unwrap()
            ]
        );
    }

    // Two extras that reference each other are reported as a cycle, closed on its first item.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_cycle() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["spam[iota]"]
            iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Cycles are not supported: extra:alpha -> extra:iota -> extra:alpha"
        )
    }

    // Referencing an undeclared extra names both the missing extra and its includer.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_missing_include() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Failed to find optional dependency `alpha` included by extra:iota"
        )
    }

    // Calling the resolver directly with an undeclared extra reports the last parent
    // in `parents` as the includer.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_missing_top_level() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            alpha = ["beta"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let mut resolved = ResolvedDependencies::default();
        let err = resolve_optional_dependency(
            "foo",
            pyproject_toml
                .project
                .as_ref()
                .unwrap()
                .optional_dependencies
                .as_ref()
                .unwrap(),
            &mut resolved,
            &mut vec![Item::Extra("bar".to_string())],
            Some("spam"),
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Failed to find optional dependency `foo` included by extra:bar"
        );
    }

    // An `include-group` table is replaced by the requirements of the included group.
    #[test]
    fn parse_pyproject_toml_dependency_groups_resolve() {
        let source = r#"
            [dependency-groups]
            alpha = ["beta", "gamma", "delta"]
            epsilon = ["eta<2.0", "theta==2024.09.01"]
            iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();

        assert_eq!(
            resolved_dependencies.dependency_groups["iota"],
            vec![
                Requirement::from_str("beta").unwrap(),
                Requirement::from_str("gamma").unwrap(),
                Requirement::from_str("delta").unwrap()
            ]
        );
    }

    // Two groups that include each other are reported as a cycle.
    #[test]
    fn parse_pyproject_toml_dependency_groups_cycle() {
        let source = r#"
            [dependency-groups]
            alpha = [{include-group = "iota"}]
            iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Cycles are not supported: group:alpha -> group:iota -> group:alpha"
        )
    }

    // Including an undeclared group names both the missing group and its includer.
    #[test]
    fn parse_pyproject_toml_dependency_groups_missing_include() {
        let source = r#"
            [dependency-groups]
            iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Failed to find dependency group `alpha` included by group:iota"
        )
    }

    // A dependency group may reference an optional dependency via a self-reference.
    #[test]
    fn parse_pyproject_toml_dependency_groups_with_optional_dependencies() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]

            [dependency-groups]
            dev = ["spam[test]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
    }

    // An extra and a group with the same name are resolved independently.
    #[test]
    fn name_collision() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            dev = ["pytest"]

            [dependency-groups]
            dev = ["ruff"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.optional_dependencies["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("ruff").unwrap()]
        );
    }

    // Resolving an extra referenced from a group does not add it to the group map.
    #[test]
    fn optional_dependencies_are_not_dependency_groups() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]

            [dependency-groups]
            dev = ["spam[test]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert!(resolved_dependencies
            .optional_dependencies
            .contains_key("test"));
        assert!(!resolved_dependencies.dependency_groups.contains_key("test"));
        assert!(resolved_dependencies.dependency_groups.contains_key("dev"));
    }

    // `spam[test]` inside a group refers to the extra `test`, not to the group `test`.
    #[test]
    fn mixed_resolution() {
        let source = r#"
            [project]
            name = "spam"

            [project.optional-dependencies]
            test = ["pytest"]
            numpy = ["numpy"]

            [dependency-groups]
            dev = ["spam[test]"]
            test = ["spam[numpy]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
        assert_eq!(
            resolved_dependencies.dependency_groups["test"],
            vec![Requirement::from_str("numpy").unwrap()]
        );
    }

    #[test]
    fn optional_dependencies_with_underscores() {
        // Test that optional dependency group names with underscores are normalized
        // when referenced in extras. PEP 685 specifies that extras should be normalized
        // by replacing _, ., - with a single -.
        let source = r#"
            [project]
            name = "foo"

            [project.optional-dependencies]
            all = [
              "foo[group-one]",
              "foo[group_two]",
            ]
            group_one = [
              "anyio>=4.9.0",
            ]
            group-two = [
              "trio>=0.31.0",
            ]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();

        // Both group-one and group_two should resolve correctly
        assert_eq!(
            resolved_dependencies.optional_dependencies["all"],
            vec![
                Requirement::from_str("anyio>=4.9.0").unwrap(),
                Requirement::from_str("trio>=0.31.0").unwrap(),
            ]
        );
    }
}