use std::collections::BTreeSet;
use owo_colors::OwoColorize;
use petgraph::visit::EdgeRef;
use petgraph::{Directed, Direction, Graph};
use rustc_hash::{FxBuildHasher, FxHashMap};
use uv_distribution_types::{DistributionMetadata, Name, SourceAnnotation, SourceAnnotations};
use uv_normalize::PackageName;
use uv_pep508::MarkerTree;
use crate::resolution::{RequirementsTxtDist, ResolutionGraphNode};
use crate::{ResolverEnvironment, ResolverOutput};
/// A [`ResolverOutput`] bundled with the options needed to render it, through its
/// [`std::fmt::Display`] implementation, in `requirements.txt` format.
#[derive(Debug)]
pub struct DisplayResolutionGraph<'a> {
/// The resolution to render.
resolution: &'a ResolverOutput,
/// Environment whose marker environment decides which input requirements, constraints,
/// and overrides contribute `# via` source annotations.
env: &'a ResolverEnvironment,
/// Packages omitted from the output entirely.
no_emit_packages: &'a [PackageName],
/// Whether to append a backslash-continued `--hash=` line for each of a package's hashes.
show_hashes: bool,
/// When `true`, nodes sharing a package name and markers are merged and their extras kept;
/// when `false`, nodes sharing a version are merged and their extras dropped.
include_extras: bool,
/// Forwarded to `RequirementsTxtDist::to_requirements_txt`; presumably controls whether
/// markers are written on each requirement line — TODO confirm.
include_markers: bool,
/// Whether to add `# via …` comments naming dependents and external sources.
include_annotations: bool,
/// Whether to add a `# from <index URL>` comment (credentials removed) for packages that
/// report an index.
include_index_annotation: bool,
/// Layout used for the `# via …` comments.
annotation_style: AnnotationStyle,
}
/// Node of the intermediate graph built while rendering: either the resolver's root or a
/// distribution already converted to its `requirements.txt` representation.
#[derive(Debug)]
enum DisplayResolutionGraphNode<'dist> {
/// The root of the resolution graph; skipped when extras are combined or stripped.
Root,
/// A resolved distribution.
Dist(RequirementsTxtDist<'dist>),
}
impl<'a> DisplayResolutionGraph<'a> {
    /// Wrap `underlying` together with the rendering options used by the `Display` impl.
    ///
    /// # Panics
    ///
    /// Panics if any of the resolution's fork markers carries a non-trivial conflict marker,
    /// because conflicting forks cannot be expressed in `requirements.txt` format.
    #[expect(clippy::fn_params_excessive_bools)]
    pub fn new(
        underlying: &'a ResolverOutput,
        env: &'a ResolverEnvironment,
        no_emit_packages: &'a [PackageName],
        show_hashes: bool,
        include_extras: bool,
        include_markers: bool,
        include_annotations: bool,
        include_index_annotation: bool,
        annotation_style: AnnotationStyle,
    ) -> Self {
        // Reject the first fork whose conflict marker is anything but trivially true.
        let offending = underlying
            .fork_markers
            .iter()
            .find(|fork_marker| !fork_marker.conflict().is_true());
        if let Some(fork_marker) = offending {
            panic!(
                "found fork marker {fork_marker:?} with non-trivial conflicting marker, \
                 cannot display resolver output with conflicts in requirements.txt format"
            );
        }

        Self {
            resolution: underlying,
            env,
            no_emit_packages,
            show_hashes,
            include_extras,
            include_markers,
            include_annotations,
            include_index_annotation,
            annotation_style,
        }
    }
}
impl std::fmt::Display for DisplayResolutionGraph<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let sources = if self.include_annotations {
let mut sources = SourceAnnotations::default();
for requirement in self.resolution.requirements.iter().filter(|requirement| {
requirement.evaluate_markers(self.env.marker_environment(), &[])
}) {
if let Some(origin) = &requirement.origin {
sources.add(
&requirement.name,
SourceAnnotation::Requirement(origin.clone()),
);
}
}
for requirement in self
.resolution
.constraints
.requirements()
.filter(|requirement| {
requirement.evaluate_markers(self.env.marker_environment(), &[])
})
{
if let Some(origin) = &requirement.origin {
sources.add(
&requirement.name,
SourceAnnotation::Constraint(origin.clone()),
);
}
}
for requirement in self
.resolution
.overrides
.requirements()
.filter(|requirement| {
requirement.evaluate_markers(self.env.marker_environment(), &[])
})
{
if let Some(origin) = &requirement.origin {
sources.add(
&requirement.name,
SourceAnnotation::Override(origin.clone()),
);
}
}
sources
} else {
SourceAnnotations::default()
};
let graph = self.resolution.graph.map(
|_index, node| match node {
ResolutionGraphNode::Root => DisplayResolutionGraphNode::Root,
ResolutionGraphNode::Dist(dist) => {
let dist = RequirementsTxtDist::from_annotated_dist(dist);
DisplayResolutionGraphNode::Dist(dist)
}
},
|_index, _edge| (),
);
let graph = if self.include_extras {
combine_extras(&graph)
} else {
strip_extras(&graph)
};
let mut nodes = graph
.node_indices()
.filter_map(|index| {
let dist = &graph[index];
let name = dist.name();
if self.no_emit_packages.contains(name) {
return None;
}
Some((index, dist))
})
.collect::<Vec<_>>();
nodes.sort_unstable_by_key(|(index, node)| (node.to_comparator(), *index));
for (index, node) in nodes {
let mut line = node
.to_requirements_txt(&self.resolution.requires_python, self.include_markers)
.to_string();
let mut has_hashes = false;
if self.show_hashes {
for hash in node.hashes {
has_hashes = true;
line.push_str(" \\\n");
line.push_str(" --hash=");
line.push_str(&hash.to_string());
}
}
let mut annotation = None;
if self.include_annotations {
let dependents = {
let mut dependents = graph
.edges_directed(index, Direction::Incoming)
.map(|edge| &graph[edge.source()])
.map(uv_distribution_types::Name::name)
.collect::<Vec<_>>();
dependents.sort_unstable();
dependents.dedup();
dependents
};
let default = BTreeSet::default();
let source = sources.get(node.name()).unwrap_or(&default);
match self.annotation_style {
AnnotationStyle::Line => match dependents.as_slice() {
[] if source.is_empty() => {}
[] if source.len() == 1 => {
let separator = if has_hashes { "\n " } else { " " };
let comment = format!("# via {}", source.iter().next().unwrap())
.green()
.to_string();
annotation = Some((separator, comment));
}
dependents => {
let separator = if has_hashes { "\n " } else { " " };
let dependents = dependents
.iter()
.map(ToString::to_string)
.chain(source.iter().map(ToString::to_string))
.collect::<Vec<_>>()
.join(", ");
let comment = format!("# via {dependents}").green().to_string();
annotation = Some((separator, comment));
}
},
AnnotationStyle::Split => match dependents.as_slice() {
[] if source.is_empty() => {}
[] if source.len() == 1 => {
let separator = "\n";
let comment = format!(" # via {}", source.iter().next().unwrap())
.green()
.to_string();
annotation = Some((separator, comment));
}
[dependent] if source.is_empty() => {
let separator = "\n";
let comment = format!(" # via {dependent}").green().to_string();
annotation = Some((separator, comment));
}
dependents => {
let separator = "\n";
let dependent = source
.iter()
.map(ToString::to_string)
.chain(dependents.iter().map(ToString::to_string))
.map(|name| format!(" # {name}"))
.collect::<Vec<_>>()
.join("\n");
let comment = format!(" # via\n{dependent}").green().to_string();
annotation = Some((separator, comment));
}
},
}
}
if let Some((separator, comment)) = annotation {
for line in format!("{line:24}{separator}{comment}").lines() {
let line = line.trim_end();
writeln!(f, "{line}")?;
}
} else {
writeln!(f, "{line}")?;
}
if self.include_index_annotation {
if let Some(index) = node.dist.index() {
let url = index.without_credentials();
writeln!(f, "{}", format!(" # from {url}").green())?;
}
}
}
Ok(())
}
}
/// Layout of the `# via …` comments that name the dependents and sources of each package
/// in `requirements.txt` output.
#[derive(Debug, Default, Copy, Clone, PartialEq, serde::Deserialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum AnnotationStyle {
/// Render all annotations after the requirement as a single, comma-separated list.
Line,
/// Render annotations on the lines following the requirement, one entry per line when
/// there is more than one.
#[default]
Split,
}
/// Node-for-node mirror of the resolver's graph (root included) with edge weights dropped,
/// built before extras are combined or stripped.
type IntermediatePetGraph<'dist> = Graph<DisplayResolutionGraphNode<'dist>, (), Directed>;
/// The reduced graph that is actually rendered: distributions only, after extras have been
/// combined or stripped.
type RequirementsTxtGraph<'dist> = Graph<RequirementsTxtDist<'dist>, (), Directed>;
/// Reduce the intermediate graph so that each (package name, markers) pair has one node.
///
/// Nodes merged into one keep the union of their extras, sorted and deduplicated. Edges
/// between distributions are re-pointed at the merged nodes, with duplicates collapsed.
/// The root node, and any edge touching it, is dropped.
fn combine_extras<'dist>(graph: &IntermediatePetGraph<'dist>) -> RequirementsTxtGraph<'dist> {
    /// Merge key: distributions with the same package name and identical markers share a node.
    fn version_marker<'dist>(dist: &'dist RequirementsTxtDist) -> (&'dist PackageName, MarkerTree) {
        (dist.name(), dist.markers)
    }

    let mut reduced = RequirementsTxtGraph::with_capacity(graph.node_count(), graph.edge_count());
    let mut slot_of = FxHashMap::with_capacity_and_hasher(graph.node_count(), FxBuildHasher);

    // The first distribution seen for a key becomes its node; later ones only contribute
    // their extras.
    for position in graph.node_indices() {
        let DisplayResolutionGraphNode::Dist(dist) = &graph[position] else {
            continue;
        };
        slot_of
            .entry(version_marker(dist))
            .and_modify(|slot| {
                let merged: &mut RequirementsTxtDist = &mut reduced[*slot];
                merged.extras.extend(dist.extras.iter().cloned());
                merged.extras.sort_unstable();
                merged.extras.dedup();
            })
            .or_insert_with(|| reduced.add_node(dist.clone()));
    }

    // Re-point every distribution-to-distribution edge at the merged nodes.
    for edge in graph.edge_references() {
        let (DisplayResolutionGraphNode::Dist(from), DisplayResolutionGraphNode::Dist(to)) =
            (&graph[edge.source()], &graph[edge.target()])
        else {
            continue;
        };
        reduced.update_edge(slot_of[&version_marker(from)], slot_of[&version_marker(to)], ());
    }

    reduced
}
/// Reduce the intermediate graph to one extra-free node per distinct distribution version.
///
/// Nodes that share a `version_id` (e.g. `flask` and `flask[dotenv]`) are merged into one
/// node whose markers are the `or` of the merged nodes' markers. Every resulting node has
/// its extras removed. Edges between distributions are re-pointed at the merged nodes,
/// with duplicates collapsed. The root node, and any edge touching it, is dropped.
fn strip_extras<'dist>(graph: &IntermediatePetGraph<'dist>) -> RequirementsTxtGraph<'dist> {
    let mut next = RequirementsTxtGraph::with_capacity(graph.node_count(), graph.edge_count());
    let mut inverse = FxHashMap::with_capacity_and_hasher(graph.node_count(), FxBuildHasher);

    // Re-add the distribution nodes, merging those that share a version.
    for index in graph.node_indices() {
        let DisplayResolutionGraphNode::Dist(dist) = &graph[index] else {
            continue;
        };
        match inverse.entry(dist.version_id()) {
            std::collections::hash_map::Entry::Occupied(entry) => {
                let index = *entry.get();
                let node: &mut RequirementsTxtDist = &mut next[index];
                // E.g. `foo[bar]==1.0; marker_a` next to `foo==1.0; marker_b`: the merged
                // `foo==1.0` must apply whenever either marker does.
                node.markers.or(dist.markers);
            }
            std::collections::hash_map::Entry::Vacant(entry) => {
                // Strip extras on insertion, not only when a second node merges in.
                // Otherwise an extra-bearing node that is the sole representative of its
                // version would keep its extras and print them despite stripping.
                let mut node = dist.clone();
                node.extras.clear();
                let index = next.add_node(node);
                entry.insert(index);
            }
        }
    }

    // Re-point every distribution-to-distribution edge at the merged nodes.
    for edge in graph.edge_indices() {
        let (source, target) = graph.edge_endpoints(edge).unwrap();
        let DisplayResolutionGraphNode::Dist(source_node) = &graph[source] else {
            continue;
        };
        let DisplayResolutionGraphNode::Dist(target_node) = &graph[target] else {
            continue;
        };
        let source = inverse[&source_node.version_id()];
        let target = inverse[&target_node.version_id()];
        next.update_edge(source, target, ());
    }

    next
}