use std::cmp::Ordering;
use std::fmt;
use std::ops::Bound;
use std::sync::Mutex;
use std::sync::MutexGuard;
use itertools::Either;
use once_cell::sync::Lazy;
use pep440_rs::{release_specifier_to_range, Operator, Version, VersionSpecifier};
use rustc_hash::FxHashMap;
use version_ranges::Ranges;
use crate::marker::MarkerValueExtra;
use crate::ExtraOperator;
use crate::{MarkerExpression, MarkerOperator, MarkerValueString, MarkerValueVersion};
/// The global interner instance, built lazily with `Interner::default` on first access.
pub(crate) static INTERNER: Lazy<Interner> = Lazy::new(Interner::default);
/// An interner for decision-diagram [`Node`]s.
///
/// Node storage lives in `shared` and is read without taking the mutex; the bookkeeping needed
/// to create nodes lives in `state` and is only reachable through [`Interner::lock`].
#[derive(Default)]
pub(crate) struct Interner {
/// Storage for every interned node, looked up by [`NodeId`].
pub(crate) shared: InternerShared,
/// Deduplication map and operation cache, guarded by a mutex.
state: Mutex<InternerState>,
}
/// The part of the interner that is accessed through a shared reference.
#[derive(Default)]
pub(crate) struct InternerShared {
/// All interned nodes; a node's position in this vector is the index encoded in its [`NodeId`].
/// Within this file, entries are only ever pushed and read.
nodes: boxcar::Vec<Node>,
}
/// The mutable interner state that is protected by the mutex in [`Interner`].
#[derive(Default)]
struct InternerState {
/// Maps each stored node to its ID so that structurally equal nodes are interned once.
unique: FxHashMap<Node, NodeId>,
/// Memoized results of `and`, keyed by the pair of operand IDs.
cache: FxHashMap<(NodeId, NodeId), NodeId>,
}
impl InternerShared {
    /// Returns the stored [`Node`] that `id` points at.
    ///
    /// Only the index portion of `id` is used for the lookup; the caller is responsible for
    /// interpreting the complement bit.
    ///
    /// # Panics
    ///
    /// Panics if the index encoded in `id` is not present in the node storage.
    pub(crate) fn node(&self, id: NodeId) -> &Node {
        let slot = id.index();
        &self.nodes[slot]
    }
}
impl Interner {
    /// Acquires the state mutex and returns a guard through which nodes can be created and
    /// combined.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub(crate) fn lock(&self) -> InternerGuard<'_> {
        let state = self.state.lock().unwrap();
        let shared = &self.shared;
        InternerGuard { state, shared }
    }
}
/// A handle to the interner that holds the state lock for as long as it is alive.
pub(crate) struct InternerGuard<'a> {
/// The locked deduplication map and operation cache.
state: MutexGuard<'a, InternerState>,
/// The node storage of the interner this guard was created from.
shared: &'a InternerShared,
}
impl InternerGuard<'_> {
/// Interns the node `(var, children)` and returns its ID.
///
/// The node is normalized before it is stored: if its first child is complemented, every child
/// is flipped and the returned ID is complemented instead, so stored nodes never start with a
/// complemented child. A node whose children are all identical is not stored at all; that
/// child is returned directly.
///
/// # Panics
///
/// Panics if `children` contains no edges.
fn create_node(&mut self, var: Variable, children: Edges) -> NodeId {
    let candidate = Node { var, children };
    let head = candidate.children.nodes().next().unwrap();

    // Move a complemented first edge onto the returned ID.
    let flipped = head.is_complement();
    let (node, first) = if flipped {
        (candidate.not(), head.not())
    } else {
        (candidate, head)
    };

    // Every edge leads to the same child: the node is redundant.
    if node.children.nodes().all(|child| child == first) {
        return if flipped { first.not() } else { first };
    }

    // Reuse the existing ID for a structurally equal node, or store a new one.
    let shared = self.shared;
    let id = *self
        .state
        .unique
        .entry(node.clone())
        .or_insert_with(|| NodeId::new(shared.nodes.push(node), false));

    if flipped {
        id.not()
    } else {
        id
    }
}
pub(crate) fn expression(&mut self, expr: MarkerExpression) -> NodeId {
let (var, children) = match expr {
MarkerExpression::Version {
key: MarkerValueVersion::PythonVersion,
specifier,
} => match python_version_to_full_version(normalize_specifier(specifier)) {
Ok(specifier) => (
Variable::Version(MarkerValueVersion::PythonFullVersion),
Edges::from_specifier(specifier),
),
Err(node) => return node,
},
MarkerExpression::VersionIn {
key: MarkerValueVersion::PythonVersion,
versions,
negated,
} => match Edges::from_python_versions(versions, negated) {
Ok(edges) => (
Variable::Version(MarkerValueVersion::PythonFullVersion),
edges,
),
Err(node) => return node,
},
MarkerExpression::Version { key, specifier } => {
(Variable::Version(key), Edges::from_specifier(specifier))
}
MarkerExpression::VersionIn {
key,
versions,
negated,
} => (
Variable::Version(key),
Edges::from_versions(&versions, negated),
),
MarkerExpression::String {
key,
operator: MarkerOperator::In,
value,
} => (Variable::In { key, value }, Edges::from_bool(true)),
MarkerExpression::String {
key,
operator: MarkerOperator::NotIn,
value,
} => (Variable::In { key, value }, Edges::from_bool(false)),
MarkerExpression::String {
key,
operator: MarkerOperator::Contains,
value,
} => (Variable::Contains { key, value }, Edges::from_bool(true)),
MarkerExpression::String {
key,
operator: MarkerOperator::NotContains,
value,
} => (Variable::Contains { key, value }, Edges::from_bool(false)),
MarkerExpression::String {
key,
operator,
value,
} => (Variable::String(key), Edges::from_string(operator, value)),
MarkerExpression::Extra {
name,
operator: ExtraOperator::Equal,
} => (Variable::Extra(name), Edges::from_bool(true)),
MarkerExpression::Extra {
name,
operator: ExtraOperator::NotEqual,
} => (Variable::Extra(name), Edges::from_bool(false)),
};
self.create_node(var, children)
}
/// Returns the disjunction of `x` and `y`.
///
/// Computed through De Morgan's law as the complement of `and(not x, not y)`.
pub(crate) fn or(&mut self, x: NodeId, y: NodeId) -> NodeId {
    let neither = self.and(x.not(), y.not());
    neither.not()
}
/// Returns the conjunction of the diagrams rooted at `xi` and `yi`.
///
/// Trivial cases (a terminal operand, identical operands, an operand and its complement) are
/// answered directly. Otherwise the result is looked up in, or computed and added to, the
/// operation cache.
pub(crate) fn and(&mut self, xi: NodeId, yi: NodeId) -> NodeId {
    // `true` is the identity element.
    if xi.is_true() {
        return yi;
    }
    if yi.is_true() || xi == yi {
        return xi;
    }
    // `false` absorbs everything, and `X and not X` is `false`.
    if xi.is_false() || yi.is_false() || xi.not() == yi {
        return NodeId::FALSE;
    }

    if let Some(&memoized) = self.state.cache.get(&(xi, yi)) {
        return memoized;
    }

    let lhs = self.shared.node(xi);
    let rhs = self.shared.node(yi);

    // Expand on whichever variable orders first; equal variables have their edges merged.
    let (var, children) = match lhs.var.cmp(&rhs.var) {
        Ordering::Less => {
            let merged = lhs.children.map(xi, |child| self.and(child, yi));
            (lhs.var.clone(), merged)
        }
        Ordering::Greater => {
            let merged = rhs.children.map(yi, |child| self.and(child, xi));
            (rhs.var.clone(), merged)
        }
        Ordering::Equal => {
            let merged = lhs
                .children
                .apply(xi, &rhs.children, yi, |l, r| self.and(l, r));
            (lhs.var.clone(), merged)
        }
    };

    let result = self.create_node(var, children);
    self.state.cache.insert((xi, yi), result);
    result
}
/// Returns `true` if the diagrams rooted at `xi` and `yi` cannot both hold at the same time.
///
/// Unlike `and`, this creates no nodes and does not use the operation cache.
pub(crate) fn is_disjoint(&mut self, xi: NodeId, yi: NodeId) -> bool {
    // `false` is disjoint with everything.
    if xi.is_false() || yi.is_false() {
        return true;
    }
    // `true` overlaps anything that is not `false`, and a node overlaps itself.
    if xi.is_true() || yi.is_true() || xi == yi {
        return false;
    }
    // A node and its complement never overlap.
    if xi.not() == yi {
        return true;
    }

    let lhs = self.shared.node(xi);
    let rhs = self.shared.node(yi);
    match lhs.var.cmp(&rhs.var) {
        Ordering::Equal => lhs.children.is_disjoint(xi, &rhs.children, yi, self),
        Ordering::Less => {
            for child in lhs.children.nodes() {
                if !self.is_disjoint(child.negate(xi), yi) {
                    return false;
                }
            }
            true
        }
        Ordering::Greater => {
            for child in rhs.children.nodes() {
                if !self.is_disjoint(child.negate(yi), xi) {
                    return false;
                }
            }
            true
        }
    }
}
/// Restricts the diagram rooted at `i` by fixing boolean variables to the values chosen by `f`.
///
/// For every node with boolean edges, `f` is asked about the node's variable. If it returns
/// `Some(value)`, the node is replaced by its `high` (for `true`) or `low` (for `false`) child;
/// if it returns `None`, the node is kept. Nodes with version or string edges are never fixed
/// themselves, but their children are restricted.
///
/// The restriction is applied to the whole diagram: the child that replaces a fixed node is
/// itself restricted, so variables nested below a fixed variable are also eliminated.
pub(crate) fn restrict(&mut self, i: NodeId, f: &impl Fn(&Variable) -> Option<bool>) -> NodeId {
    if matches!(i, NodeId::TRUE | NodeId::FALSE) {
        return i;
    }

    let node = self.shared.node(i);
    if let Edges::Boolean { high, low } = node.children {
        if let Some(value) = f(&node.var) {
            let child = if value { high } else { low };
            // Keep restricting below the selected child; returning it untouched would leave
            // any further variables that `f` fixes in the result.
            return self.restrict(child.negate(i), f);
        }
    }

    // This variable stays; restrict each of its children.
    let children = node.children.map(i, |child| self.restrict(child, f));
    self.create_node(node.var.clone(), children)
}
pub(crate) fn simplify_python_versions(
&mut self,
i: NodeId,
py_lower: Bound<&Version>,
py_upper: Bound<&Version>,
) -> NodeId {
if matches!(i, NodeId::TRUE | NodeId::FALSE)
|| matches!((py_lower, py_upper), (Bound::Unbounded, Bound::Unbounded))
{
return i;
}
let node = self.shared.node(i);
let Node {
var: Variable::Version(MarkerValueVersion::PythonFullVersion),
children: Edges::Version { ref edges },
} = node
else {
let children = node.children.map(i, |node_id| {
self.simplify_python_versions(node_id, py_lower, py_upper)
});
return self.create_node(node.var.clone(), children);
};
let py_range = Ranges::from_range_bounds((py_lower.cloned(), py_upper.cloned()));
if py_range.is_empty() {
return NodeId::FALSE;
}
let mut new = SmallVec::new();
for &(ref range, node) in edges {
let overlap = range.intersection(&py_range);
if overlap.is_empty() {
continue;
}
new.push((overlap.clone(), node));
}
let &(ref first_range, first_node_id) = new.first().unwrap();
let first_upper = first_range.bounding_range().unwrap().1;
let clipped = Ranges::from_range_bounds((Bound::Unbounded, first_upper.cloned()));
*new.first_mut().unwrap() = (clipped, first_node_id);
let &(ref last_range, last_node_id) = new.last().unwrap();
let last_lower = last_range.bounding_range().unwrap().0;
let clipped = Ranges::from_range_bounds((last_lower.cloned(), Bound::Unbounded));
*new.last_mut().unwrap() = (clipped, last_node_id);
self.create_node(node.var.clone(), Edges::Version { edges: new })
.negate(i)
}
/// Rewrites the diagram rooted at `i` so that it only holds when `python_full_version` lies
/// within the bounds `py_lower..py_upper`.
///
/// `NodeId::FALSE`, and any node when both bounds are unbounded, are returned unchanged. If the
/// bounds describe an empty range the result is `NodeId::FALSE`. `NodeId::TRUE` becomes a
/// `python_full_version` node built from the bounds. Nodes for other variables are rebuilt with
/// complexified children. For a `python_full_version` node, edges outside the bounds are dropped
/// and the space outside the bounds is mapped to an always-false child.
///
/// # Panics
///
/// Panics if no edge of a `python_full_version` node overlaps the bounds.
pub(crate) fn complexify_python_versions(
&mut self,
i: NodeId,
py_lower: Bound<&Version>,
py_upper: Bound<&Version>,
) -> NodeId {
if matches!(i, NodeId::FALSE)
|| matches!((py_lower, py_upper), (Bound::Unbounded, Bound::Unbounded))
{
return i;
}
let py_range = Ranges::from_range_bounds((py_lower.cloned(), py_upper.cloned()));
if py_range.is_empty() {
// Nothing can satisfy the bounds.
return NodeId::FALSE;
}
if matches!(i, NodeId::TRUE) {
// `true` within the bounds is exactly the bounds themselves.
let var = Variable::Version(MarkerValueVersion::PythonFullVersion);
let edges = Edges::Version {
edges: Edges::from_range(&py_range),
};
return self.create_node(var, edges).negate(i);
}
let node = self.shared.node(i);
let Node {
var: Variable::Version(MarkerValueVersion::PythonFullVersion),
children: Edges::Version { ref edges },
} = node
else {
// Not a `python_full_version` node: complexify below it.
let children = node.children.map(i, |node_id| {
self.complexify_python_versions(node_id, py_lower, py_upper)
});
return self.create_node(node.var.clone(), children);
};
// Keep only the edges that overlap the bounds (the ranges themselves are not narrowed yet).
let mut new: SmallVec<_> = edges
.iter()
.filter(|(range, _)| !py_range.intersection(range).is_empty())
.cloned()
.collect();
assert!(
!new.is_empty(),
"expected at least one non-empty intersection"
);
// The child stored for "always false". The final result is passed through `.negate(i)`, so
// the stored child is pre-negated to cancel that out when `i` is complemented.
let exclude_node_id = NodeId::FALSE.negate(i);
if !matches!(py_lower, Bound::Unbounded) {
let &(ref first_range, first_node_id) = new.first().unwrap();
let first_upper = first_range.bounding_range().unwrap().1;
if exclude_node_id == first_node_id {
// The first edge is already always-false: stretch it down to negative infinity rather
// than adding a second, adjacent always-false edge.
let clipped = Ranges::from_range_bounds((Bound::Unbounded, first_upper.cloned()));
*new.first_mut().unwrap() = (clipped, first_node_id);
} else {
// Start the first edge at the lower bound and put an always-false edge in front of it
// covering everything below the lower bound.
let clipped = Ranges::from_range_bounds((py_lower.cloned(), first_upper.cloned()));
*new.first_mut().unwrap() = (clipped, first_node_id);
let py_range_lower =
Ranges::from_range_bounds((py_lower.cloned(), Bound::Unbounded));
new.insert(0, (py_range_lower.complement(), NodeId::FALSE.negate(i)));
}
}
if !matches!(py_upper, Bound::Unbounded) {
let &(ref last_range, last_node_id) = new.last().unwrap();
let last_lower = last_range.bounding_range().unwrap().0;
if exclude_node_id == last_node_id {
// The last edge is already always-false: stretch it up to positive infinity.
let clipped = Ranges::from_range_bounds((last_lower.cloned(), Bound::Unbounded));
*new.last_mut().unwrap() = (clipped, last_node_id);
} else {
// End the last edge at the upper bound and append an always-false edge covering
// everything above the upper bound.
let clipped = Ranges::from_range_bounds((last_lower.cloned(), py_upper.cloned()));
*new.last_mut().unwrap() = (clipped, last_node_id);
let py_range_upper =
Ranges::from_range_bounds((Bound::Unbounded, py_upper.cloned()));
new.push((py_range_upper.complement(), exclude_node_id));
}
}
self.create_node(node.var.clone(), Edges::Version { edges: new })
.negate(i)
}
}
/// A decision variable of the diagram.
///
/// The derived `Ord` is what `and` and `is_disjoint` compare to decide which of two nodes is
/// expanded first, so the order of the variants below (and of their fields) is significant.
#[derive(PartialOrd, Ord, PartialEq, Eq, Hash, Clone, Debug)]
pub(crate) enum Variable {
/// The value of a version key; created together with version-range edges.
Version(MarkerValueVersion),
/// The value of a string key; created together with string-range edges.
String(MarkerValueString),
/// A boolean variable created for an `in` / `not in` expression on `key` with `value`.
In {
key: MarkerValueString,
value: String,
},
/// A boolean variable created for a `contains` / `not contains` expression on `key` with
/// `value`.
Contains {
key: MarkerValueString,
value: String,
},
/// A boolean variable created for an extra `==` / `!=` expression.
Extra(MarkerValueExtra),
}
/// A decision node: a variable plus the outgoing edges for each of its possible outcomes.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub(crate) struct Node {
/// The variable this node decides on.
pub(crate) var: Variable,
/// The children reached for each outcome of `var`.
pub(crate) children: Edges,
}
impl Node {
    /// Returns this node with every child edge complemented; the variable is unchanged.
    fn not(self) -> Node {
        let Node { var, children } = self;
        Node {
            var,
            children: children.not(),
        }
    }
}
/// An identifier for a node, packed into a single `usize`.
///
/// The lowest bit is the complement flag; the remaining bits hold the node's index plus one.
/// The two smallest values are reserved for the terminals `TRUE` and `FALSE`, which are each
/// other's complement.
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub(crate) struct NodeId(usize);

impl NodeId {
    /// The terminal `true`.
    pub(crate) const TRUE: NodeId = NodeId(0);

    /// The terminal `false`.
    pub(crate) const FALSE: NodeId = NodeId(1);

    /// Builds the ID for the node stored at `index`, optionally complemented.
    fn new(index: usize, complement: bool) -> NodeId {
        // Offset by one so that index 0 does not collide with the terminals, then make room for
        // the complement bit.
        let slot = (index + 1) << 1;
        if complement {
            NodeId(slot | 1)
        } else {
            NodeId(slot)
        }
    }

    /// Returns the storage index of this ID, ignoring the complement bit.
    ///
    /// Must not be called on `TRUE` or `FALSE`, which have no storage index.
    fn index(self) -> usize {
        (self.0 / 2) - 1
    }

    /// Returns `true` if the complement bit is set.
    fn is_complement(self) -> bool {
        self.0 & 1 != 0
    }

    /// Returns the complement of this ID.
    pub(crate) fn not(self) -> NodeId {
        NodeId(self.0 ^ 1)
    }

    /// Returns this ID, complemented if and only if `parent` is complemented.
    pub(crate) fn negate(self, parent: NodeId) -> NodeId {
        NodeId(self.0 ^ (parent.0 & 1))
    }

    /// Returns `true` if this is the `FALSE` terminal.
    pub(crate) fn is_false(self) -> bool {
        self.0 == Self::FALSE.0
    }

    /// Returns `true` if this is the `TRUE` terminal.
    pub(crate) fn is_true(self) -> bool {
        self.0 == Self::TRUE.0
    }
}
/// A small vector with inline capacity for five elements, used for edge lists.
type SmallVec<T> = smallvec::SmallVec<[T; 5]>;
/// The outgoing edges of a [`Node`]: one child per possible outcome of the node's variable.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
#[allow(clippy::large_enum_variant)] pub(crate) enum Edges {
/// Edges of a version variable: each range of versions leads to one child.
Version {
edges: SmallVec<(Ranges<Version>, NodeId)>,
},
/// Edges of a string variable: each range of strings leads to one child.
String {
edges: SmallVec<(Ranges<String>, NodeId)>,
},
/// Edges of a boolean variable: `high` is followed when the variable is `true` and `low`
/// when it is `false`.
Boolean {
high: NodeId,
low: NodeId,
},
}
impl Edges {
/// Returns the boolean edges of a variable that must equal `complemented` for the expression
/// to hold: `high` is `TRUE` and `low` is `FALSE` when `complemented` is `true`, and the other
/// way around when it is `false`.
fn from_bool(complemented: bool) -> Edges {
    let (high, low) = if complemented {
        (NodeId::TRUE, NodeId::FALSE)
    } else {
        (NodeId::FALSE, NodeId::TRUE)
    };
    Edges::Boolean { high, low }
}
/// Returns the string edges for comparing a string key against `value` with `operator`.
///
/// # Panics
///
/// Panics for `~=` and for any operator that is not an equality or ordering comparison
/// (`in` and `contains` style operators are handled as boolean variables elsewhere).
fn from_string(operator: MarkerOperator, value: String) -> Edges {
    let accepted: Ranges<String> = match operator {
        MarkerOperator::Equal => Ranges::singleton(value),
        MarkerOperator::NotEqual => Ranges::singleton(value).complement(),
        MarkerOperator::GreaterThan => Ranges::strictly_higher_than(value),
        MarkerOperator::GreaterEqual => Ranges::higher_than(value),
        MarkerOperator::LessThan => Ranges::strictly_lower_than(value),
        MarkerOperator::LessEqual => Ranges::lower_than(value),
        MarkerOperator::TildeEqual => unreachable!("string comparisons with ~= are ignored"),
        _ => unreachable!("`in` and `contains` are treated as boolean variables"),
    };

    let edges = Edges::from_range(&accepted);
    Edges::String { edges }
}
/// Returns the version edges for a version specifier.
///
/// The specifier is normalized with `normalize_specifier` before it is converted to a range.
fn from_specifier(specifier: VersionSpecifier) -> Edges {
    let normalized = normalize_specifier(specifier);
    let accepted = release_specifier_to_range(normalized);
    let edges = Edges::from_range(&accepted);
    Edges::Version { edges }
}
/// Returns the `python_full_version` edges for a `python_version in ...` expression, or for
/// `python_version not in ...` when `negated` is `true`.
///
/// Each version is treated as a `python_version == <version>` specifier, translated with
/// `python_version_to_full_version`, and the resulting ranges are unioned.
///
/// # Errors
///
/// If translating any of the versions yields a terminal node, that node is returned as the
/// error and the remaining versions are not examined.
fn from_python_versions(versions: Vec<Version>, negated: bool) -> Result<Edges, NodeId> {
    let mut range = Ranges::empty();
    for version in versions {
        // `versions` is owned, so each version can be moved into its specifier.
        let specifier = VersionSpecifier::equals_version(version);
        let specifier = python_version_to_full_version(specifier)?;
        let pubgrub_specifier = release_specifier_to_range(normalize_specifier(specifier));
        range = range.union(&pubgrub_specifier);
    }

    if negated {
        range = range.complement();
    }

    Ok(Edges::Version {
        edges: Edges::from_range(&range),
    })
}
/// Returns the version edges that accept exactly the given `versions`, or everything except
/// them when `negated` is `true`.
///
/// Takes a slice so that callers are not forced to hold a `Vec`; a `&Vec<Version>` argument
/// still coerces.
fn from_versions(versions: &[Version], negated: bool) -> Edges {
    let mut range = Ranges::empty();
    for version in versions {
        range = range.union(&Ranges::singleton(version.clone()));
    }

    if negated {
        range = range.complement();
    }

    Edges::Version {
        edges: Edges::from_range(&range),
    }
}
/// Splits the whole value space into edges: every segment of `range` leads to `TRUE` and every
/// segment of its complement leads to `FALSE`. The edges are returned sorted by the start of
/// their range.
fn from_range<T>(range: &Ranges<T>) -> SmallVec<(Ranges<T>, NodeId)>
where
    T: Ord + Clone,
{
    let rejected = range.complement();

    let accepting = range.iter().map(|(start, end)| {
        let segment = Ranges::from_range_bounds((start.clone(), end.clone()));
        (segment, NodeId::TRUE)
    });
    let rejecting = rejected.iter().map(|(start, end)| {
        let segment = Ranges::from_range_bounds((start.clone(), end.clone()));
        (segment, NodeId::FALSE)
    });

    let mut edges: SmallVec<(Ranges<T>, NodeId)> = accepting.chain(rejecting).collect();
    edges.sort_by(|(lhs, _), (rhs, _)| compare_disjoint_range_start(lhs, rhs));
    edges
}
/// Combines these edges with `right_edges` by calling `apply` on each pair of children that can
/// be reached together.
///
/// `parent` and `right_parent` are the IDs through which the two edge sets were reached; each
/// child is negated by its parent before it is passed to `apply`.
///
/// # Panics
///
/// Panics if the two edge sets are of different kinds.
fn apply(
    &self,
    parent: NodeId,
    right_edges: &Edges,
    right_parent: NodeId,
    mut apply: impl FnMut(NodeId, NodeId) -> NodeId,
) -> Edges {
    match (self, right_edges) {
        (Edges::Version { edges: lhs }, Edges::Version { edges: rhs }) => {
            let edges = Edges::apply_ranges(lhs, parent, rhs, right_parent, apply);
            Edges::Version { edges }
        }
        (Edges::String { edges: lhs }, Edges::String { edges: rhs }) => {
            let edges = Edges::apply_ranges(lhs, parent, rhs, right_parent, apply);
            Edges::String { edges }
        }
        (
            Edges::Boolean {
                high: lhs_high,
                low: lhs_low,
            },
            Edges::Boolean {
                high: rhs_high,
                low: rhs_low,
            },
        ) => {
            // `high` is combined before `low`.
            let high = apply(lhs_high.negate(parent), rhs_high.negate(right_parent));
            let low = apply(lhs_low.negate(parent), rhs_low.negate(right_parent));
            Edges::Boolean { high, low }
        }
        _ => unreachable!("cannot merge two `Edges` of different types"),
    }
}
/// Combines two lists of range edges.
///
/// Every left range is intersected with every right range. For each non-empty intersection,
/// `apply` is called on the two children (each negated by its parent). If the result equals the
/// child of the most recently emitted edge and the two ranges touch, the ranges are joined;
/// otherwise a new edge is emitted.
fn apply_ranges<T>(
    left_edges: &SmallVec<(Ranges<T>, NodeId)>,
    left_parent: NodeId,
    right_edges: &SmallVec<(Ranges<T>, NodeId)>,
    right_parent: NodeId,
    mut apply: impl FnMut(NodeId, NodeId) -> NodeId,
) -> SmallVec<(Ranges<T>, NodeId)>
where
    T: Clone + Ord,
{
    let mut merged: SmallVec<(Ranges<T>, NodeId)> = SmallVec::new();

    for (lhs_range, lhs_child) in left_edges {
        for (rhs_range, rhs_child) in right_edges {
            let overlap = rhs_range.intersection(lhs_range);
            if overlap.is_empty() {
                continue;
            }

            let node = apply(
                lhs_child.negate(left_parent),
                rhs_child.negate(right_parent),
            );

            // Extend the previous edge when it leads to the same child and is adjacent.
            if let Some((range, prev)) = merged.last_mut() {
                if *prev == node && can_conjoin(range, &overlap) {
                    *range = range.union(&overlap);
                    continue;
                }
            }
            merged.push((overlap, node));
        }
    }

    merged
}
/// Returns `true` if every pair of children that can be reached together from these edges and
/// `right_edges` is disjoint according to `interner`.
///
/// Each child is negated by its parent (`parent` or `right_parent`) before it is compared.
///
/// # Panics
///
/// Panics if the two edge sets are of different kinds.
fn is_disjoint(
    &self,
    parent: NodeId,
    right_edges: &Edges,
    right_parent: NodeId,
    interner: &mut InternerGuard<'_>,
) -> bool {
    match (self, right_edges) {
        (Edges::Version { edges: lhs }, Edges::Version { edges: rhs }) => {
            Edges::is_disjoint_ranges(lhs, parent, rhs, right_parent, interner)
        }
        (Edges::String { edges: lhs }, Edges::String { edges: rhs }) => {
            Edges::is_disjoint_ranges(lhs, parent, rhs, right_parent, interner)
        }
        (
            Edges::Boolean {
                high: lhs_high,
                low: lhs_low,
            },
            Edges::Boolean {
                high: rhs_high,
                low: rhs_low,
            },
        ) => {
            // The `low` pair is only examined when the `high` pair is disjoint.
            if !interner.is_disjoint(lhs_high.negate(parent), rhs_high.negate(right_parent)) {
                return false;
            }
            interner.is_disjoint(lhs_low.negate(parent), rhs_low.negate(right_parent))
        }
        _ => unreachable!("cannot merge two `Edges` of different types"),
    }
}
/// Returns `true` if, for every left and right edge whose ranges overlap, the two children
/// (each negated by its parent) are disjoint according to `interner`.
///
/// Stops at the first overlapping pair whose children are not disjoint.
fn is_disjoint_ranges<T>(
    left_edges: &SmallVec<(Ranges<T>, NodeId)>,
    left_parent: NodeId,
    right_edges: &SmallVec<(Ranges<T>, NodeId)>,
    right_parent: NodeId,
    interner: &mut InternerGuard<'_>,
) -> bool
where
    T: Clone + Ord,
{
    left_edges.iter().all(|(lhs_range, lhs_child)| {
        right_edges.iter().all(|(rhs_range, rhs_child)| {
            // Pairs that cannot be reached together are irrelevant.
            rhs_range.intersection(lhs_range).is_empty()
                || interner.is_disjoint(
                    lhs_child.negate(left_parent),
                    rhs_child.negate(right_parent),
                )
        })
    })
}
/// Returns a copy of these edges in which every child has been negated by `parent` and then
/// replaced by the result of `f`.
///
/// Range edges are visited in order; for boolean edges `low` is visited before `high`.
fn map(&self, parent: NodeId, mut f: impl FnMut(NodeId) -> NodeId) -> Edges {
    match self {
        Edges::Version { edges } => {
            let edges = edges
                .iter()
                .map(|(range, child)| (range.clone(), f(child.negate(parent))))
                .collect();
            Edges::Version { edges }
        }
        Edges::String { edges } => {
            let edges = edges
                .iter()
                .map(|(range, child)| (range.clone(), f(child.negate(parent))))
                .collect();
            Edges::String { edges }
        }
        Edges::Boolean { high, low } => {
            let low = f(low.negate(parent));
            let high = f(high.negate(parent));
            Edges::Boolean { high, low }
        }
    }
}
/// Returns an iterator over the children of these edges, as stored (not negated by any
/// parent). Range edges are yielded in order; boolean edges yield `high` and then `low`.
fn nodes(&self) -> impl Iterator<Item = NodeId> + '_ {
    match self {
        Edges::Version { edges } => {
            let children = edges.iter().map(|&(_, child)| child);
            Either::Left(Either::Left(children))
        }
        Edges::String { edges } => {
            let children = edges.iter().map(|&(_, child)| child);
            Either::Left(Either::Right(children))
        }
        Edges::Boolean { high, low } => {
            let children = [*high, *low];
            Either::Right(children.into_iter())
        }
    }
}
/// Returns these edges with every child complemented; the ranges are left as they are.
fn not(self) -> Edges {
    match self {
        Edges::Version { mut edges } => {
            edges
                .iter_mut()
                .for_each(|(_, child)| *child = child.not());
            Edges::Version { edges }
        }
        Edges::String { mut edges } => {
            edges
                .iter_mut()
                .for_each(|(_, child)| *child = child.not());
            Edges::String { edges }
        }
        Edges::Boolean { high, low } => {
            let (high, low) = (high.not(), low.not());
            Edges::Boolean { high, low }
        }
    }
}
}
/// Rebuilds `specifier` from its operator and the release segments of its version only.
///
/// For non-star operators, trailing zero segments are removed from the release, except when
/// the only non-zero segment is the first one or all segments are zero, in which case the
/// release is left as it is. Star operators keep their release unchanged.
///
/// # Panics
///
/// Panics if the rebuilt operator/version pair is rejected by `VersionSpecifier::from_version`.
fn normalize_specifier(specifier: VersionSpecifier) -> VersionSpecifier {
    let (operator, version) = specifier.into_parts();
    let full = version.release();

    let release = if operator.is_star() {
        full
    } else {
        match full.iter().rposition(|segment| *segment != 0) {
            Some(end) if end > 0 => &full[..=end],
            _ => full,
        }
    };

    VersionSpecifier::from_version(operator, Version::new(release)).unwrap()
}
/// Translates a specifier written against `python_version` into one to be used against
/// `python_full_version`.
///
/// Returns `Err` with a terminal node when the outcome is already decided without looking at
/// the version: `NodeId::FALSE` if the specifier can never hold and `NodeId::TRUE` if it always
/// holds.
fn python_version_to_full_version(specifier: VersionSpecifier) -> Result<VersionSpecifier, NodeId> {
// Extract `(major, minor)` when the release has one or two segments.
let major_minor = match *specifier.version().release() {
// A single-segment star specifier (such as `== 3.*`) is passed through untouched.
[_major] if specifier.operator().is_star() => return Ok(specifier),
// A lone major segment is treated as `major.0`.
[major] => Some((major, 0)),
[major, minor] => Some((major, minor)),
// Anything else is handled in the `else` branch below.
_ => None,
};
if let Some((major, minor)) = major_minor {
let version = Version::new([major, minor]);
Ok(match specifier.operator() {
// `== major.minor` becomes `== major.minor.*`.
Operator::Equal | Operator::ExactEqual => {
VersionSpecifier::equals_star_version(version)
}
// `!= major.minor` becomes `!= major.minor.*`.
Operator::NotEqual => VersionSpecifier::not_equals_star_version(version),
// `> major.minor` becomes `>= major.(minor + 1)`.
Operator::GreaterThan => {
VersionSpecifier::greater_than_equal_version(Version::new([major, minor + 1]))
}
// `<` and `>=` are kept as written.
Operator::LessThan => specifier,
Operator::GreaterThanEqual => specifier,
// `<= major.minor` becomes `< major.(minor + 1)`.
Operator::LessThanEqual => {
VersionSpecifier::less_than_version(Version::new([major, minor + 1]))
}
// Star and `~=` specifiers are kept as written.
Operator::EqualStar | Operator::NotEqualStar | Operator::TildeEqual => specifier,
})
} else {
// The release has neither one nor two segments here; only the first two are used.
// NOTE(review): a release with no segments would reach `unreachable!` — presumably such
// versions cannot be constructed; confirm.
let &[major, minor, ..] = specifier.version().release() else {
unreachable!()
};
Ok(match specifier.operator() {
// Equality-like operators are reported as never satisfiable.
// NOTE(review): presumably because `python_version` only carries `major.minor` — confirm.
Operator::Equal | Operator::ExactEqual | Operator::EqualStar | Operator::TildeEqual => {
return Err(NodeId::FALSE)
}
// Inequality operators are reported as always satisfied.
Operator::NotEqual | Operator::NotEqualStar => return Err(NodeId::TRUE),
// `<` and `<=` both become `< major.(minor + 1)`.
Operator::LessThan | Operator::LessThanEqual => {
VersionSpecifier::less_than_version(Version::new([major, minor + 1]))
}
// `>` and `>=` both become `>= major.(minor + 1)`.
Operator::GreaterThan | Operator::GreaterThanEqual => {
VersionSpecifier::greater_than_equal_version(Version::new([major, minor + 1]))
}
})
}
}
/// Orders two ranges by their lower bounds.
///
/// An unbounded start on the first range sorts it first, and otherwise an unbounded start on
/// the second range sorts that one first. For equal start values, an inclusive start sorts
/// before an exclusive one. The caller is expected to pass disjoint ranges.
///
/// # Panics
///
/// Panics if either range is empty.
fn compare_disjoint_range_start<T>(range1: &Ranges<T>, range2: &Ranges<T>) -> Ordering
where
    T: Ord,
{
    let start1 = range1.bounding_range().unwrap().0;
    let start2 = range2.bounding_range().unwrap().0;

    match (start1, start2) {
        (Bound::Unbounded, _) => Ordering::Less,
        (_, Bound::Unbounded) => Ordering::Greater,
        (Bound::Included(lhs), Bound::Excluded(rhs)) if lhs == rhs => Ordering::Less,
        (Bound::Excluded(lhs), Bound::Included(rhs)) if lhs == rhs => Ordering::Greater,
        (
            Bound::Included(lhs) | Bound::Excluded(lhs),
            Bound::Included(rhs) | Bound::Excluded(rhs),
        ) => lhs.cmp(rhs),
    }
}
/// Returns `true` if `range2` starts exactly where `range1` ends, with no gap and no overlap:
/// both bounds carry the same value and exactly one of them is inclusive.
///
/// Returns `false` if either range is empty.
fn can_conjoin<T>(range1: &Ranges<T>, range2: &Ranges<T>) -> bool
where
    T: Ord + Clone,
{
    let (Some((_, end)), Some((start, _))) = (range1.bounding_range(), range2.bounding_range())
    else {
        return false;
    };

    matches!(
        (end, start),
        (Bound::Included(lhs), Bound::Excluded(rhs))
            | (Bound::Excluded(lhs), Bound::Included(rhs))
            if lhs == rhs
    )
}
impl fmt::Debug for NodeId {
    /// Formats the terminals as `false` / `true`. Any other ID is formatted as the node it
    /// refers to in the global interner, with the node's children complemented first if the ID
    /// is complemented.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_false() {
            return f.write_str("false");
        }
        if self.is_true() {
            return f.write_str("true");
        }

        let node = INTERNER.shared.node(*self);
        if self.is_complement() {
            let complemented = node.clone().not();
            write!(f, "{complemented:?}")
        } else {
            write!(f, "{node:?}")
        }
    }
}
#[cfg(test)]
mod tests;