1use crate::{DependencyGroupSpecifier, DependencyGroups, ResolvedDependencies};
2use indexmap::IndexMap;
3use pep508_rs::{ExtraName, Requirement};
4use std::fmt::Display;
5use std::str::FromStr;
6use thiserror::Error;
7
8fn normalize_name(name: &str) -> String {
10 ExtraName::from_str(name)
11 .map(|extra| extra.to_string())
12 .unwrap_or_else(|_| name.to_string())
13}
14
/// Error returned when flattening optional dependencies or dependency groups fails.
///
/// This is an opaque wrapper around [`ResolveErrorKind`]. Because of `#[error(transparent)]`,
/// its `Display` output is exactly that of the wrapped kind. `#[from]` lets resolution code
/// convert a kind into this error with `.into()`.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ResolveError(#[from] ResolveErrorKind);
18
/// The specific reason resolution failed.
#[derive(Debug, Error)]
pub enum ResolveErrorKind {
    /// A self-reference such as `project[name]` named an extra that is not defined.
    /// `included_by` is the extra or group whose requirement list contained the reference.
    #[error("Failed to find optional dependency `{name}` included by {included_by}")]
    OptionalDependencyNotFound { name: String, included_by: Item },
    /// An `include-group` entry named a dependency group that is not defined.
    /// `included_by` is the group that contained the include.
    #[error("Failed to find dependency group `{name}` included by {included_by}")]
    DependencyGroupNotFound { name: String, included_by: Item },
    /// Resolution revisited an extra or group that was still being resolved.
    /// Despite its name, this variant is also used for cycles between extras.
    #[error("Cycles are not supported: {0}")]
    DependencyGroupCycle(Cycle),
}
28
/// A chain of extras/groups that loops back onto its first element.
///
/// Rendered as `a -> b -> a`, i.e. the first item is repeated at the end to close the loop.
/// An empty chain renders as the empty string.
#[derive(Debug)]
pub struct Cycle(Vec<Item>);

impl Display for Cycle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Some(start) = self.0.first() else {
            return Ok(());
        };
        // Every node is followed by an arrow; writing the first node last closes the loop.
        for node in &self.0 {
            write!(f, "{node} -> ")?;
        }
        write!(f, "{start}")
    }
}

/// An entry on the resolution chain, used for cycle detection and in error messages.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Item {
    /// An extra from `optional_dependencies`; displayed as `extra:<name>`.
    Extra(String),
    /// A dependency group; displayed as `group:<name>`.
    Group(String),
}

impl Display for Item {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (kind, name) = match self {
            Self::Extra(name) => ("extra", name),
            Self::Group(name) => ("group", name),
        };
        write!(f, "{kind}:{name}")
    }
}
65
66pub(crate) fn resolve(
67 self_reference_name: Option<&str>,
68 optional_dependencies: Option<&IndexMap<String, Vec<Requirement>>>,
69 dependency_groups: Option<&DependencyGroups>,
70) -> Result<ResolvedDependencies, ResolveError> {
71 let mut resolved_dependencies = ResolvedDependencies::default();
72
73 if let Some(optional_dependencies) = optional_dependencies {
75 for extra in optional_dependencies.keys() {
76 resolve_optional_dependency(
77 extra,
78 optional_dependencies,
79 &mut resolved_dependencies,
80 &mut Vec::new(),
81 self_reference_name,
82 )?;
83 }
84 }
85
86 if let Some(dependency_groups) = dependency_groups {
88 for group in dependency_groups.keys() {
89 resolve_dependency_group(
91 group,
92 optional_dependencies.unwrap_or(&IndexMap::default()),
93 dependency_groups,
94 &mut resolved_dependencies,
95 &mut Vec::new(),
96 self_reference_name,
97 )?;
98 }
99 }
100
101 Ok(resolved_dependencies)
102}
103
104fn resolve_optional_dependency(
106 extra: &str,
107 optional_dependencies: &IndexMap<String, Vec<Requirement>>,
108 resolved: &mut ResolvedDependencies,
109 parents: &mut Vec<Item>,
110 project_name: Option<&str>,
111) -> Result<Vec<Requirement>, ResolveError> {
112 if let Some(requirements) = resolved.optional_dependencies.get(extra) {
113 return Ok(requirements.clone());
114 }
115
116 let normalized_extra = normalize_name(extra);
117
118 let unresolved_requirements = optional_dependencies
121 .iter()
122 .find(|(key, _)| normalize_name(key) == normalized_extra)
123 .map(|(_, reqs)| reqs);
124
125 let Some(unresolved_requirements) = unresolved_requirements else {
126 let parent = parents
127 .iter()
128 .last()
129 .expect("missing optional dependency must have parent")
130 .clone();
131 return Err(ResolveErrorKind::OptionalDependencyNotFound {
132 name: extra.to_string(),
133 included_by: parent,
134 }
135 .into());
136 };
137
138 let item = Item::Extra(extra.to_string());
140 if parents.contains(&item) {
141 return Err(ResolveErrorKind::DependencyGroupCycle(Cycle(parents.clone())).into());
142 }
143 parents.push(item);
144
145 let mut resolved_requirements = Vec::with_capacity(unresolved_requirements.len());
147 for unresolved_requirement in unresolved_requirements.iter() {
148 if project_name
150 .is_some_and(|project_name| project_name == unresolved_requirement.name.to_string())
151 {
152 for extra in &unresolved_requirement.extras {
155 let extra_string = extra.to_string();
156 resolved_requirements.extend(resolve_optional_dependency(
157 &extra_string,
158 optional_dependencies,
159 resolved,
160 parents,
161 project_name,
162 )?);
163 }
164 } else {
165 resolved_requirements.push(unresolved_requirement.clone())
166 }
167 }
168 resolved
169 .optional_dependencies
170 .insert(extra.to_string(), resolved_requirements.clone());
171 parents.pop();
172 Ok(resolved_requirements)
173}
174
175fn resolve_dependency_group(
177 dep_group: &String,
178 optional_dependencies: &IndexMap<String, Vec<Requirement>>,
179 dependency_groups: &DependencyGroups,
180 resolved: &mut ResolvedDependencies,
181 parents: &mut Vec<Item>,
182 project_name: Option<&str>,
183) -> Result<Vec<Requirement>, ResolveError> {
184 if let Some(requirements) = resolved.dependency_groups.get(dep_group) {
185 return Ok(requirements.clone());
186 }
187
188 let Some(unresolved_requirements) = dependency_groups.get(dep_group) else {
189 let parent = parents
190 .iter()
191 .last()
192 .expect("missing optional dependency must have parent")
193 .clone();
194 return Err(ResolveErrorKind::DependencyGroupNotFound {
195 name: dep_group.to_string(),
196 included_by: parent,
197 }
198 .into());
199 };
200
201 let item = Item::Group(dep_group.to_string());
203 if parents.contains(&item) {
204 return Err(ResolveErrorKind::DependencyGroupCycle(Cycle(parents.clone())).into());
205 }
206 parents.push(item);
207
208 let mut resolved_requirements = Vec::with_capacity(unresolved_requirements.len());
210 for unresolved_requirement in unresolved_requirements.iter() {
211 match unresolved_requirement {
212 DependencyGroupSpecifier::String(spec) => {
213 if project_name.is_some_and(|project_name| project_name == spec.name.to_string()) {
214 for extra in &spec.extras {
215 resolved_requirements.extend(resolve_optional_dependency(
216 extra.as_ref(),
217 optional_dependencies,
218 resolved,
219 parents,
220 project_name,
221 )?);
222 }
223 } else {
224 resolved_requirements.push(spec.clone())
225 }
226 }
227 DependencyGroupSpecifier::Table { include_group } => {
228 resolved_requirements.extend(resolve_dependency_group(
229 include_group,
230 optional_dependencies,
231 dependency_groups,
232 resolved,
233 parents,
234 project_name,
235 )?);
236 }
237 }
238 }
239 resolved
241 .dependency_groups
242 .insert(dep_group.to_string(), resolved_requirements.clone());
243 parents.pop();
244 Ok(resolved_requirements)
245}
246
#[cfg(test)]
mod tests {
    //! Tests for resolving self-referencing extras and dependency groups: successful
    //! flattening, cycle and missing-include error messages, separation of the extras and
    //! groups namespaces, and separator-insensitive extra names.
    use pep508_rs::Requirement;
    use std::str::FromStr;

    use crate::resolution::{resolve_optional_dependency, Item};
    use crate::{PyProjectToml, ResolvedDependencies};

    /// A self-referencing extra (`spam[alpha]`) is replaced by the requirements of `alpha`.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_resolve() {
        let source = r#"[project]
        name = "spam"

        [project.optional-dependencies]
        alpha = ["beta", "gamma", "delta"]
        epsilon = ["eta<2.0", "theta==2024.09.01"]
        iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();

        assert_eq!(
            resolved_dependencies.optional_dependencies["iota"],
            vec![
                Requirement::from_str("beta").unwrap(),
                Requirement::from_str("gamma").unwrap(),
                Requirement::from_str("delta").unwrap()
            ]
        );
    }

    /// Two extras that include each other produce a cycle error naming both.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_cycle() {
        let source = r#"
        [project]
        name = "spam"

        [project.optional-dependencies]
        alpha = ["spam[iota]"]
        iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Cycles are not supported: extra:alpha -> extra:iota -> extra:alpha"
        )
    }

    /// Referencing an undefined extra reports the extra that contained the reference.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_missing_include() {
        let source = r#"
        [project]
        name = "spam"

        [project.optional-dependencies]
        iota = ["spam[alpha]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Failed to find optional dependency `alpha` included by extra:iota"
        )
    }

    /// Calls the resolver directly for an undefined extra, with a synthetic parent `bar` on
    /// the chain, and checks that the error names that parent.
    #[test]
    fn parse_pyproject_toml_optional_dependencies_missing_top_level() {
        let source = r#"
        [project]
        name = "spam"

        [project.optional-dependencies]
        alpha = ["beta"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let mut resolved = ResolvedDependencies::default();
        let err = resolve_optional_dependency(
            "foo",
            pyproject_toml
                .project
                .as_ref()
                .unwrap()
                .optional_dependencies
                .as_ref()
                .unwrap(),
            &mut resolved,
            &mut vec![Item::Extra("bar".to_string())],
            Some("spam"),
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Failed to find optional dependency `foo` included by extra:bar"
        );
    }

    /// An `include-group` table is replaced by the requirements of the included group.
    #[test]
    fn parse_pyproject_toml_dependency_groups_resolve() {
        let source = r#"
        [dependency-groups]
        alpha = ["beta", "gamma", "delta"]
        epsilon = ["eta<2.0", "theta==2024.09.01"]
        iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();

        assert_eq!(
            resolved_dependencies.dependency_groups["iota"],
            vec![
                Requirement::from_str("beta").unwrap(),
                Requirement::from_str("gamma").unwrap(),
                Requirement::from_str("delta").unwrap()
            ]
        );
    }

    /// Two groups that include each other produce a cycle error naming both.
    #[test]
    fn parse_pyproject_toml_dependency_groups_cycle() {
        let source = r#"
        [dependency-groups]
        alpha = [{include-group = "iota"}]
        iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Cycles are not supported: group:alpha -> group:iota -> group:alpha"
        )
    }

    /// Including an undefined group reports the group that contained the include.
    #[test]
    fn parse_pyproject_toml_dependency_groups_missing_include() {
        let source = r#"
        [dependency-groups]
        iota = [{include-group = "alpha"}]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        assert_eq!(
            pyproject_toml.resolve().unwrap_err().to_string(),
            "Failed to find dependency group `alpha` included by group:iota"
        )
    }

    /// A dependency group can pull in a project extra through a self-reference.
    #[test]
    fn parse_pyproject_toml_dependency_groups_with_optional_dependencies() {
        let source = r#"
        [project]
        name = "spam"

        [project.optional-dependencies]
        test = ["pytest"]

        [dependency-groups]
        dev = ["spam[test]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
    }

    /// An extra and a group with the same name are resolved independently.
    #[test]
    fn name_collision() {
        let source = r#"
        [project]
        name = "spam"

        [project.optional-dependencies]
        dev = ["pytest"]

        [dependency-groups]
        dev = ["ruff"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.optional_dependencies["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("ruff").unwrap()]
        );
    }

    /// Resolving an extra on behalf of a group records it only under the extras map.
    #[test]
    fn optional_dependencies_are_not_dependency_groups() {
        let source = r#"
        [project]
        name = "spam"

        [project.optional-dependencies]
        test = ["pytest"]

        [dependency-groups]
        dev = ["spam[test]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert!(resolved_dependencies
            .optional_dependencies
            .contains_key("test"));
        assert!(!resolved_dependencies.dependency_groups.contains_key("test"));
        assert!(resolved_dependencies.dependency_groups.contains_key("dev"));
    }

    /// A group named `test` does not shadow the extra named `test`: `spam[test]` still
    /// resolves to the extra.
    #[test]
    fn mixed_resolution() {
        let source = r#"
        [project]
        name = "spam"

        [project.optional-dependencies]
        test = ["pytest"]
        numpy = ["numpy"]

        [dependency-groups]
        dev = ["spam[test]"]
        test = ["spam[numpy]"]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();
        assert_eq!(
            resolved_dependencies.dependency_groups["dev"],
            vec![Requirement::from_str("pytest").unwrap()]
        );
        assert_eq!(
            resolved_dependencies.dependency_groups["test"],
            vec![Requirement::from_str("numpy").unwrap()]
        );
    }

    /// Extras referenced with `-` resolve to keys spelled with `_`, and vice versa.
    #[test]
    fn optional_dependencies_with_underscores() {
        let source = r#"
        [project]
        name = "foo"

        [project.optional-dependencies]
        all = [
            "foo[group-one]",
            "foo[group_two]",
        ]
        group_one = [
            "anyio>=4.9.0",
        ]
        group-two = [
            "trio>=0.31.0",
        ]
        "#;
        let pyproject_toml = PyProjectToml::new(source).unwrap();
        let resolved_dependencies = pyproject_toml.resolve().unwrap();

        assert_eq!(
            resolved_dependencies.optional_dependencies["all"],
            vec![
                Requirement::from_str("anyio>=4.9.0").unwrap(),
                Requirement::from_str("trio>=0.31.0").unwrap(),
            ]
        );
    }
}