use std::collections::BTreeMap;
use std::ops::Bound;
use ahash::HashMap;
use crate::{
DependencyProvider, Requirement, VariableId, VersionSetId,
internal::{arena::Arena, id::ClauseId, small_vec::SmallVec},
requirement::RequirementMap,
solver_id::{IdMap, IdSet, SolverId},
};
use super::{
conditions::{Disjunction, DisjunctionId},
decision::Decision,
decision_map::DecisionMap,
};
/// Index of a [`TrackedClause`] in `DecideQueue::clauses`, handed out in registration order.
type TrackedClauseId = u32;
/// Queue key of a tracked clause.
///
/// The parent's slot index occupies the high 32 bits and the clause's index within that
/// parent the low 32 bits, so the derived ordering sorts by parent slot first and by
/// registration order within the parent second.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
#[repr(transparent)]
struct ClausePosition(u64);

impl ClausePosition {
    /// Packs `(parent_pos, clause_pos)` into a single ordered key.
    ///
    /// # Panics
    ///
    /// Panics if either index does not fit in a `u32` (the width of each packed half).
    /// `u32::MAX` itself is accepted because it is representable.
    fn new(parent_pos: usize, clause_pos: usize) -> Self {
        const OVERFLOW: &str = "clause position exceeds the packed u64 layout";
        // Checked conversions instead of a hand-written bound: the previous `<` comparison
        // wrongly rejected `u32::MAX`, which fits exactly.
        let parent = u32::try_from(parent_pos).expect(OVERFLOW);
        let clause = u32::try_from(clause_pos).expect(OVERFLOW);
        Self((u64::from(parent) << 32) | u64::from(clause))
    }
}
/// Bookkeeping for one registered requires clause.
#[derive(Copy, Clone)]
struct TrackedClause {
/// Key of this clause in `queue`/`hot_queue`; orders clauses by the parent's slot index,
/// then by registration order within that parent.
position: ClausePosition,
/// Variable whose requirement this clause encodes; the clause is only queued while this
/// variable is assigned `true`.
parent: VariableId,
/// Requirement whose cached state decides whether the clause still needs a decision.
requirement: Requirement,
/// Optional condition; the clause is only considered when every literal of the
/// disjunction evaluates to `false`.
condition: Option<DisjunctionId>,
/// Solver clause reported back alongside a decision for this clause.
clause_id: ClauseId,
/// Whether any name this clause was registered with is in the hot name set; clauses
/// with this flag are mirrored into `hot_queue`.
hot: bool,
}
/// Cached result of walking a requirement's candidates (computed by `eval_requirement`).
#[derive(Copy, Clone)]
enum RequirementState {
/// The cached result may be stale and must be recomputed before use.
Dirty,
/// The walk found `by` assigned `true`, so the requirement needs no decision.
Satisfied { by: VariableId },
/// No candidate was assigned `true`; the walk stopped at the first unassigned one.
Frontier {
/// First unassigned candidate in walk order.
candidate: VariableId,
/// Version set that `candidate` was found in.
version_set: VersionSetId,
/// Number of unassigned candidates in `version_set` (at least 1).
count: u32,
},
}
/// Per-requirement cache entry stored in `DecideQueue::requirement_states`.
struct RequirementEntry {
/// Current cached evaluation of the requirement.
state: RequirementState,
/// Whether the requirement has already been recorded in `requirements_by_candidate`
/// under each of its candidates (done once, on its first evaluation).
occurrences_registered: bool,
}
/// A decision proposed by `DecideQueue::next_decision`.
pub(crate) struct QueueDecision {
/// Unassigned candidate to decide on.
pub candidate: VariableId,
/// Parent variable of the clause that produced `candidate`.
pub required_by: VariableId,
/// Solver clause that produced `candidate`.
pub clause_id: ClauseId,
/// Activity of the name of the version set `candidate` was found in.
pub package_activity: f32,
/// Number of unassigned candidates in that version set.
pub candidate_count: u32,
}
/// Work counters collected only when the `diagnostics` feature is enabled.
#[cfg(feature = "diagnostics")]
#[derive(Default)]
pub(crate) struct DecideQueueCounters {
/// Calls to `route_touched`, i.e. variables routed by `sync`.
pub sync_touches: u64,
/// Clauses removed from the queues by `next_decision` because `inspect` rejected them.
pub dequeues: u64,
/// Calls to `inspect`.
pub selection_visits: u64,
/// Clauses considered by the scan after the first candidate in `next_decision` (counted
/// in both hot-only and full-queue mode, after the explicit-requirement filter).
pub hot_visits: u64,
/// Requirement walks actually performed (cache misses in `eval_requirement`).
pub walk_evals: u64,
}
/// Incremental selection state for the solver's next decision.
///
/// Clauses are registered once with `register_clause`. `sync` replays changes to the
/// decision trail so that affected clauses are (re)queued and cached requirement states are
/// invalidated; `next_decision` then inspects queued clauses lazily, dropping the ones that
/// no longer need a decision.
pub(crate) struct DecideQueue<D: DependencyProvider> {
/// Every registered clause, indexed by `TrackedClauseId` in registration order.
clauses: Vec<TrackedClause>,
/// Slot in `clauses_by_parent` for each parent variable, assigned on its first clause.
parent_positions: HashMap<VariableId, u32>,
/// Clauses of each parent slot, in registration order.
clauses_by_parent: Vec<Vec<TrackedClauseId>>,
/// Clauses that may need a decision, keyed (and therefore ordered) by `ClausePosition`.
queue: BTreeMap<ClausePosition, TrackedClauseId>,
/// The hot entries of `queue`, stored under the same keys.
hot_queue: BTreeMap<ClausePosition, TrackedClauseId>,
/// Names marked by `mark_name_hot`.
hot_names: <D::NameId as SolverId>::Set,
/// Clauses indexed by each distinct name they were registered with.
clauses_by_name: HashMap<D::NameId, Vec<TrackedClauseId>>,
/// Requirements indexed by each of their candidates; filled lazily the first time a
/// requirement is evaluated.
requirements_by_candidate: HashMap<VariableId, SmallVec<Requirement>>,
/// Clauses indexed by each variable appearing in their condition disjunction.
clauses_by_condition_variable: HashMap<VariableId, Vec<TrackedClauseId>>,
/// Cached evaluation state per requirement.
requirement_states: RequirementMap<RequirementEntry>,
/// Clauses indexed by their requirement.
clauses_by_requirement: RequirementMap<Vec<TrackedClauseId>>,
/// Variables of the decision trail as of the last `sync`.
mirror: Vec<VariableId>,
/// When set, `next_decision` scans only `hot_queue` after its first candidate and may stop
/// early once a candidate reaches `max_activity`.
hot_only: bool,
/// Work counters, present only with the `diagnostics` feature.
#[cfg(feature = "diagnostics")]
pub(crate) counters: DecideQueueCounters,
}
impl<D: DependencyProvider> Default for DecideQueue<D> {
fn default() -> Self {
Self {
clauses: Vec::new(),
parent_positions: HashMap::default(),
clauses_by_parent: Vec::new(),
queue: BTreeMap::new(),
hot_queue: BTreeMap::new(),
hot_names: Default::default(),
clauses_by_name: HashMap::default(),
requirements_by_candidate: HashMap::default(),
clauses_by_condition_variable: HashMap::default(),
requirement_states: RequirementMap::default(),
clauses_by_requirement: RequirementMap::default(),
mirror: Vec::new(),
hot_only: true,
#[cfg(feature = "diagnostics")]
counters: DecideQueueCounters::default(),
}
}
}
fn enqueue_clause(
queue: &mut BTreeMap<ClausePosition, TrackedClauseId>,
hot_queue: &mut BTreeMap<ClausePosition, TrackedClauseId>,
clauses: &[TrackedClause],
map: &DecisionMap,
id: TrackedClauseId,
) {
let clause = &clauses[id as usize];
if map.value(clause.parent) != Some(true) {
return;
}
queue.insert(clause.position, id);
if clause.hot {
hot_queue.insert(clause.position, id);
}
}
impl<D: DependencyProvider> DecideQueue<D> {
pub(crate) fn set_standard_activity_params(&mut self, standard: bool) {
self.hot_only = standard;
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn register_clause(
&mut self,
parent: VariableId,
requirement: Requirement,
condition: Option<DisjunctionId>,
clause_id: ClauseId,
names: impl IntoIterator<Item = D::NameId>,
disjunctions: &Arena<DisjunctionId, Disjunction>,
parent_value: Option<bool>,
) {
let parent_pos = *self.parent_positions.entry(parent).or_insert_with(|| {
self.clauses_by_parent.push(Vec::new());
(self.clauses_by_parent.len() - 1) as u32
}) as usize;
let clause_pos = self.clauses_by_parent[parent_pos].len();
let position = ClausePosition::new(parent_pos, clause_pos);
let id = self.clauses.len() as TrackedClauseId;
let mut hot = false;
let mut seen_names: SmallVec<D::NameId> = SmallVec::empty();
for name in names {
if seen_names.as_slice().contains(&name) {
continue;
}
seen_names.push(name);
if self.hot_names.contains(name) {
hot = true;
}
self.clauses_by_name.entry(name).or_default().push(id);
}
if let Some(condition) = condition {
for literal in &disjunctions[condition].literals {
self.clauses_by_condition_variable
.entry(literal.variable())
.or_default()
.push(id);
}
}
self.requirement_states
.get_or_insert_with(requirement, || RequirementEntry {
state: RequirementState::Dirty,
occurrences_registered: false,
});
self.clauses_by_requirement
.get_or_insert_with(requirement, Vec::new)
.push(id);
self.clauses_by_parent[parent_pos].push(id);
self.clauses.push(TrackedClause {
position,
parent,
requirement,
condition,
clause_id,
hot,
});
if parent_value == Some(true) {
self.queue.insert(position, id);
if hot {
self.hot_queue.insert(position, id);
}
}
}
/// Marks `name` as hot.
///
/// The first time a name is marked, every clause registered under it that is not hot yet
/// gets flagged hot, and those currently in `queue` are mirrored into `hot_queue`. Marking
/// an already-hot name does nothing.
pub(crate) fn mark_name_hot(&mut self, name: D::NameId) {
    let newly_hot = self.hot_names.insert(name);
    let affected = if newly_hot {
        self.clauses_by_name.get(&name)
    } else {
        None
    };
    let Some(ids) = affected else {
        return;
    };
    for id in ids.iter().copied() {
        let clause = &mut self.clauses[id as usize];
        if !clause.hot {
            clause.hot = true;
            let position = clause.position;
            if let Some(&queued) = self.queue.get(&position) {
                self.hot_queue.insert(position, queued);
            }
        }
    }
}
/// Brings the queue in line with the solver's decision trail.
///
/// Mirrored entries below `floor` are treated as unchanged since the previous call. Every
/// mirrored variable at or above `floor` is routed again (its assignment may have been
/// undone), the mirror is cut back to `floor`, and every trail entry from `floor` on is
/// mirrored and routed.
///
/// `floor` is clamped to both the mirrored length and `trail.len()`.
/// * Without the first clamp, a floor past the end of the mirror would leave the trail
///   entries between the two lengths unrouted and misalign the mirror with the trail.
/// * Without the second, a floor past the end of the trail would panic on the slice.
pub(crate) fn sync(&mut self, floor: usize, trail: &[Decision], map: &DecisionMap) {
    let floor = floor.min(self.mirror.len()).min(trail.len());
    // Indexed loop: `route_touched` needs `&mut self`, so the mirror cannot stay borrowed.
    for i in floor..self.mirror.len() {
        let variable = self.mirror[i];
        self.route_touched(variable, map);
    }
    self.mirror.truncate(floor);
    for decision in &trail[floor..] {
        self.mirror.push(decision.variable);
        self.route_touched(decision.variable, map);
    }
}
fn route_touched(&mut self, variable: VariableId, map: &DecisionMap) {
#[cfg(feature = "diagnostics")]
{
self.counters.sync_touches += 1;
}
let Self {
clauses,
parent_positions,
clauses_by_parent,
queue,
hot_queue,
requirements_by_candidate,
clauses_by_condition_variable,
requirement_states,
clauses_by_requirement,
..
} = self;
let value = map.value(variable);
if value == Some(true) {
if let Some(&parent_pos) = parent_positions.get(&variable) {
for &id in &clauses_by_parent[parent_pos as usize] {
enqueue_clause(queue, hot_queue, clauses, map, id);
}
}
}
if let Some(requirements) = requirements_by_candidate.get(&variable) {
for &requirement in requirements.as_slice() {
let entry = requirement_states
.get_mut(requirement)
.expect("occurrence-registered requirement has a cache entry");
match entry.state {
RequirementState::Dirty => {}
RequirementState::Frontier { .. } => entry.state = RequirementState::Dirty,
RequirementState::Satisfied { by } => {
if map.value(by) == Some(true) {
continue;
}
entry.state = RequirementState::Dirty;
if let Some(woken) = clauses_by_requirement.get(requirement) {
for &id in woken {
enqueue_clause(queue, hot_queue, clauses, map, id);
}
}
}
}
}
}
if let Some(woken) = clauses_by_condition_variable.get(&variable) {
for &id in woken {
enqueue_clause(queue, hot_queue, clauses, map, id);
}
}
}
/// Returns the cached state of `requirement`, recomputing it first if the entry is dirty.
///
/// On the first recomputation the requirement is also recorded in `requirements_by_candidate`
/// under each of its candidates, so that `route_touched` can invalidate it later.
///
/// The walk visits `requirement.version_sets(provider)` in order, zipped with the candidate
/// lists from `sorted_candidates`.
/// * The first candidate assigned `true` yields `Satisfied`.
/// * Otherwise the result is `Frontier`, holding the first unassigned candidate, its version
///   set, and the number of unassigned candidates in that version set.
///
/// # Panics
///
/// Panics if `requirement` has no cache entry, or (via `unreachable!`) if every candidate is
/// assigned `false`.
fn eval_requirement(
requirement_states: &mut RequirementMap<RequirementEntry>,
requirements_by_candidate: &mut HashMap<VariableId, SmallVec<Requirement>>,
requirement: Requirement,
map: &DecisionMap,
sorted_candidates: &RequirementMap<Vec<Vec<VariableId>>>,
provider: &D,
#[cfg(feature = "diagnostics")] counters: &mut DecideQueueCounters,
) -> RequirementState {
let entry = requirement_states
.get_mut(requirement)
.expect("every registered clause created a cache entry");
// Cache hit: only dirty entries are recomputed.
if !matches!(entry.state, RequirementState::Dirty) {
return entry.state;
}
#[cfg(feature = "diagnostics")]
{
counters.walk_evals += 1;
}
let version_set_candidates = &sorted_candidates[requirement];
// One-time reverse indexing so that changes to any candidate reach this requirement.
if !entry.occurrences_registered {
entry.occurrences_registered = true;
for &candidate in version_set_candidates.iter().flatten() {
requirements_by_candidate
.entry(candidate)
.or_insert_with(SmallVec::empty)
.push(requirement);
}
}
// First unassigned candidate, its version set, and how many unassigned candidates that
// version set has.
let mut first: Option<(VariableId, VersionSetId, u32)> = None;
// NOTE(review): `zip` stops at the shorter side; this relies on `version_sets(provider)`
// and the cached candidate lists having the same length and order.
'walk: for (version_set, candidates) in requirement
.version_sets(provider)
.zip(version_set_candidates)
{
for &candidate in candidates {
match map.value(candidate) {
// Any `true` candidate satisfies the requirement; nothing else matters.
Some(true) => {
entry.state = RequirementState::Satisfied { by: candidate };
break 'walk;
}
Some(false) => {}
// Unassigned: either the first one seen, or one more in the first one's version set.
None => match first.as_mut() {
Some((_, first_version_set, count)) => {
if *first_version_set == version_set {
*count += 1;
}
}
None => first = Some((candidate, version_set, 1)),
},
}
}
}
// Still dirty means the walk found no `true` candidate.
if matches!(entry.state, RequirementState::Dirty) {
let Some((candidate, version_set, count)) = first else {
unreachable!(
"when we get here it means that all candidates have been assigned false. This should not be able to happen at this point because during propagation the solvable should have been assigned false as well."
)
};
entry.state = RequirementState::Frontier {
candidate,
version_set,
count,
};
}
entry.state
}
/// Decides whether `clause` still needs a decision, and if so which one.
///
/// Returns `None` in any of these cases:
/// * the clause's parent is not assigned `true`;
/// * its condition does not hold (some literal does not evaluate to `false`);
/// * its requirement is already satisfied.
///
/// Otherwise returns the requirement's frontier: the first unassigned candidate, its version
/// set, and the number of unassigned candidates in that version set.
fn inspect(
    &mut self,
    clause: TrackedClause,
    map: &DecisionMap,
    sorted_candidates: &RequirementMap<Vec<Vec<VariableId>>>,
    disjunctions: &Arena<DisjunctionId, Disjunction>,
    provider: &D,
) -> Option<(VariableId, VersionSetId, u32)> {
    #[cfg(feature = "diagnostics")]
    {
        self.counters.selection_visits += 1;
    }
    // `&&` short-circuits, so the condition is only evaluated when the parent is `true`.
    let applies = map.value(clause.parent) == Some(true)
        && match clause.condition {
            None => true,
            Some(condition) => disjunctions[condition]
                .literals
                .iter()
                .all(|literal| literal.eval(map) == Some(false)),
        };
    if !applies {
        return None;
    }
    let state = Self::eval_requirement(
        &mut self.requirement_states,
        &mut self.requirements_by_candidate,
        clause.requirement,
        map,
        sorted_candidates,
        provider,
        #[cfg(feature = "diagnostics")]
        &mut self.counters,
    );
    match state {
        RequirementState::Frontier {
            candidate,
            version_set,
            count,
        } => Some((candidate, version_set, count)),
        RequirementState::Satisfied { .. } => None,
        RequirementState::Dirty => {
            unreachable!("eval_requirement never leaves the entry dirty")
        }
    }
}
/// Proposes the next decision, or returns `None` when no queued clause needs one.
///
/// Clauses are visited in `ClausePosition` order.
/// 1. Leading clauses that `inspect` rejects are removed from both `queue` and `hot_queue`.
///    The first clause it accepts becomes the provisional best.
/// 2. Unless that best is already unbeatable, the clauses after it are scanned (only those
///    in `hot_queue` when `hot_only` is set). A scanned clause replaces the best only if its
///    activity is strictly higher AND its candidate count strictly lower.
/// 3. Once the best comes from a root (explicit) clause, non-root clauses are skipped
///    without being inspected. Clauses rejected during the scan are removed as well.
///
/// `name_activity` gives the activity of each candidate's version-set name. `max_activity`
/// is only consulted in hot-only mode, to stop the scan once the best reaches it.
pub(crate) fn next_decision(
&mut self,
map: &DecisionMap,
sorted_candidates: &RequirementMap<Vec<Vec<VariableId>>>,
disjunctions: &Arena<DisjunctionId, Disjunction>,
name_activity: &<D::NameId as SolverId>::Map<f32>,
max_activity: f32,
provider: &D,
) -> Option<QueueDecision> {
/// Provisional winner of the selection.
struct Best {
/// Queue key of the winning clause; the scan continues after it.
position: ClausePosition,
/// Whether the winning clause's parent is the root variable.
explicit: bool,
activity: f32,
count: u32,
/// `(candidate, parent, clause_id)` reported in the resulting `QueueDecision`.
decision: (VariableId, VariableId, ClauseId),
}
let hot_only = self.hot_only;
// Replacement needs a strictly lower count, so a single-candidate best cannot lose. In
// hot-only mode it also needs a strictly higher activity, so a best already at
// `max_activity` cannot lose either.
let unbeatable =
|best: &Best| best.count == 1 || (hot_only && best.activity == max_activity);
let mut best: Option<Best> = None;
// Phase 1: pop rejected clauses off the front until one yields a candidate.
while let Some((&position, &id)) = self.queue.first_key_value() {
let clause = self.clauses[id as usize];
match self.inspect(clause, map, sorted_candidates, disjunctions, provider) {
None => {
self.queue.pop_first();
self.hot_queue.remove(&position);
#[cfg(feature = "diagnostics")]
{
self.counters.dequeues += 1;
}
}
Some((candidate, version_set, count)) => {
let activity = name_activity.get(provider.version_set_name(version_set));
best = Some(Best {
position,
explicit: clause.parent == VariableId::root(),
activity,
count,
decision: (candidate, clause.parent, clause.clause_id),
});
break;
}
}
}
let mut best = best?;
// Phase 2: look past the provisional best for a strictly better clause.
if !unbeatable(&best) {
let mut cursor = best.position;
loop {
// NOTE(review): hot-only mode scans only `hot_queue`. This assumes clauses outside it
// cannot have a higher activity than the current best; confirm that activity is only
// bumped for hot names.
let next = if hot_only {
self.hot_queue
.range((Bound::Excluded(cursor), Bound::Unbounded))
.next()
} else {
self.queue
.range((Bound::Excluded(cursor), Bound::Unbounded))
.next()
}
.map(|(&position, &id)| (position, id));
let Some((position, id)) = next else {
break;
};
// The range is re-queried every iteration, so removing entries below is safe.
cursor = position;
let clause = self.clauses[id as usize];
let is_explicit = clause.parent == VariableId::root();
if best.explicit && !is_explicit {
continue;
}
#[cfg(feature = "diagnostics")]
{
self.counters.hot_visits += 1;
}
match self.inspect(clause, map, sorted_candidates, disjunctions, provider) {
None => {
self.queue.remove(&position);
self.hot_queue.remove(&position);
#[cfg(feature = "diagnostics")]
{
self.counters.dequeues += 1;
}
}
Some((candidate, version_set, count)) => {
let activity = name_activity.get(provider.version_set_name(version_set));
// Both conditions must improve: strictly higher activity AND strictly
// lower candidate count.
if best.activity >= activity {
continue;
}
if best.count <= count {
continue;
}
best = Best {
position,
explicit: is_explicit,
activity,
count,
decision: (candidate, clause.parent, clause.clause_id),
};
if unbeatable(&best) {
break;
}
}
}
}
}
let (candidate, required_by, clause_id) = best.decision;
Some(QueueDecision {
candidate,
required_by,
clause_id,
package_activity: best.activity,
candidate_count: best.count,
})
}
/// Debug-only consistency check of the queue against a full recomputation.
///
/// Asserts, in this order:
/// 1. every clause that still needs a decision is in `queue` (parent `true`, condition
///    holding, no candidate assigned `true`);
/// 2. every clause's `hot` flag matches whether any of its version-set names is hot, and hot
///    clauses are present in `queue` and `hot_queue` together or not at all;
/// 3. in hot-only mode, `max_activity` equals the largest activity in `name_activity`
///    (starting from 0.0).
#[cfg(debug_assertions)]
pub(crate) fn debug_assert_invariants(
    &self,
    map: &DecisionMap,
    sorted_candidates: &RequirementMap<Vec<Vec<VariableId>>>,
    disjunctions: &Arena<DisjunctionId, Disjunction>,
    name_activity: &<D::NameId as SolverId>::Map<f32>,
    max_activity: f32,
    provider: &D,
) {
    let is_true = |variable: VariableId| map.value(variable) == Some(true);

    let needs_decision = self.clauses.iter().enumerate().filter(|&(_, clause)| {
        if !is_true(clause.parent) {
            return false;
        }
        // `None` yields an empty iterator, so an unconditional clause always passes.
        let condition_holds = clause.condition.into_iter().all(|condition| {
            disjunctions[condition]
                .literals
                .iter()
                .all(|literal| literal.eval(map) == Some(false))
        });
        condition_holds
            && !sorted_candidates[clause.requirement]
                .iter()
                .flatten()
                .any(|&candidate| is_true(candidate))
    });
    for (id, clause) in needs_decision {
        assert!(
            self.queue.contains_key(&clause.position),
            "eligible requires clause {id} is not queued"
        );
    }

    for (id, clause) in self.clauses.iter().enumerate() {
        let should_be_hot = clause
            .requirement
            .version_sets(provider)
            .map(|version_set| provider.version_set_name(version_set))
            .any(|name| self.hot_names.contains(name));
        assert_eq!(
            clause.hot, should_be_hot,
            "clause {id} hot flag out of sync with the hot name set"
        );
        if clause.hot {
            let queued = self.queue.contains_key(&clause.position);
            let hot_queued = self.hot_queue.contains_key(&clause.position);
            assert_eq!(queued, hot_queued, "hot queue out of lockstep for clause {id}");
        }
    }

    if self.hot_only {
        let mut largest = 0.0f32;
        name_activity.for_each(|&activity| largest = largest.max(activity));
        assert_eq!(
            max_activity, largest,
            "max_activity diverged from the largest stored activity"
        );
    }
}
}