use std::cmp;
use rangelist::IntervalIterator;
use rustc_hash::FxHashSet;
use crate::{
Conjunction, IntVal,
actions::{
InitActions, IntEvent, IntInspectionActions, IntPropCond, PostingActions,
PropagationActions, ReasoningContext, ReasoningEngine,
},
constraints::{IntSolverActions, Propagator},
helpers::trailed_partition::TrailedPartition,
solver::{IntLitMeaning, engine::Engine, queue::PriorityLevel},
};
/// Reusable buffers for the breadth-first augmenting-path search over the
/// variable/value matching.
#[derive(Clone, Debug)]
struct AugmentingPathScratch {
    /// Variable indices in the order the BFS reached them. After a failed
    /// search it holds every variable that was visited.
    queue: Vec<usize>,
    /// BFS parent of each variable, indexed by variable; `usize::MAX` marks a
    /// variable that has not been reached.
    parent: Vec<usize>,
}
/// Propagator requiring the given integer variables to take pairwise distinct
/// values.
///
/// Filtering is matching based (Régin-style). The propagator keeps a
/// variable/value matching, computes strongly connected components of the
/// residual graph, and removes values that link different components.
#[derive(Clone, Debug)]
pub struct IntUniqueDomain<I> {
    /// The variables, the value range they share, and the current matching.
    graph: VariableValueMatching<I>,
    /// Indices of variables whose domain changed since the last propagation
    /// (also cleared on backtrack).
    dirty_vars: FxHashSet<usize>,
    /// Trailed partition of variable indices into blocks. Blocks are split as
    /// SCCs are discovered.
    partition: TrailedPartition,
    /// Scratch space for the augmenting-path search.
    bfs: AugmentingPathScratch,
    /// Scratch space for the SCC computation.
    tarjan: TarjanScratch,
}
/// One frame of the explicit DFS stack used by the iterative Tarjan pass.
#[derive(Clone, Debug)]
struct TarjanFrame {
    /// Graph node expanded by this frame.
    node: usize,
    /// Start of this node's successor list within `TarjanScratch::neighbours`.
    frame_start: usize,
    /// End (exclusive) of this node's successor list.
    frame_end: usize,
    /// Position in `neighbours` of the next successor to examine.
    i: usize,
}
/// Reusable state for the iterative Tarjan SCC computation.
///
/// Nodes are numbered as follows:
/// - `0..n_vars` are variables;
/// - `n_vars..n_vars + n_values` are values (value offset plus `n_vars`);
/// - the final slot of each per-node vector is a dummy node.
#[derive(Clone, Debug)]
struct TarjanScratch {
    /// Tarjan's stack of nodes belonging to SCCs that are still open.
    dfs_stack: Vec<usize>,
    /// Whether each node is currently on `dfs_stack`.
    dfs_on_stack: Vec<bool>,
    /// Discovery index of each node; 0 means "not visited yet".
    dfs_index: Vec<usize>,
    /// Tarjan low-link value of each node.
    low_link: Vec<usize>,
    /// Explicit recursion stack replacing recursive DFS calls.
    work_stack: Vec<TarjanFrame>,
    /// Successor lists of all open frames, stacked back to back.
    neighbours: Vec<usize>,
    /// Variable indices of the SCC currently being processed.
    vars_buf: Vec<usize>,
    /// Value offsets of the SCC currently being processed.
    vals_buf: Vec<usize>,
}
/// Bipartite matching between the constrained variables and the values they
/// may take.
#[derive(Clone, Debug)]
struct VariableValueMatching<I> {
    /// The constrained variables.
    vars: Vec<I>,
    /// Smallest lower bound over all variables when the matching was created.
    /// Value offsets are relative to it.
    union_domain_lb: IntVal,
    /// Offset of the value matched to each variable, if any.
    var_to_val: Vec<Option<usize>>,
    /// Variable matched to each value offset, if any.
    val_to_var: Vec<Option<usize>>,
}
impl AugmentingPathScratch {
    /// Allocates BFS scratch space for `n_left` variables, none of which has
    /// been reached yet.
    fn new(n_left: usize) -> Self {
        // `usize::MAX` is the "no BFS parent" marker.
        let parent = std::iter::repeat(usize::MAX).take(n_left).collect();
        Self {
            parent,
            queue: Vec::new(),
        }
    }
}
impl<I> IntUniqueDomain<I> {
    /// Finds a new match for variable `start_var` by a breadth-first search
    /// for an augmenting path.
    ///
    /// Any existing match of `start_var` is released first. The search goes
    /// from a variable to every value in its current domain:
    /// - a matched value leads on to the variable holding it;
    /// - the first free value found ends the search, and the matching is
    ///   flipped along the recorded BFS parents.
    ///
    /// # Errors
    ///
    /// If no free value is reachable, the previous match of `start_var` is
    /// restored and a conflict is declared. At that point every value in the
    /// domains of the visited variables (`self.bfs.queue`) is matched to
    /// another visited variable. The visited set therefore has fewer values
    /// than members, and the conflict is explained by `build_hall_set_reason`
    /// over that set.
    fn find_augmenting_path<E>(
        &mut self,
        start_var: usize,
        ctx: &mut E::PropagationContext<'_>,
    ) -> Result<(), E::Conflict>
    where
        E: ReasoningEngine,
        I: IntSolverActions<E>,
    {
        // Release the current match so `start_var` is treated as free.
        let matched_val_idx = self.graph.var_to_val[start_var];
        if let Some(val_idx) = matched_val_idx {
            self.graph.val_to_var[val_idx] = None;
            self.graph.var_to_val[start_var] = None;
        }
        self.bfs.queue.clear();
        self.bfs.queue.push(start_var);
        // `usize::MAX` marks a variable the BFS has not reached yet.
        self.bfs.parent.fill(usize::MAX);
        let mut queue_head = 0;
        while queue_head < self.bfs.queue.len() {
            let var_idx = self.bfs.queue[queue_head];
            queue_head += 1;
            for val in self.graph.vars[var_idx].domain(ctx).iter().flatten() {
                let val_idx = self.graph.value_index(val);
                if let Some(matched_var) = self.graph.val_to_var[val_idx] {
                    // Continue through the variable currently holding `val`.
                    if self.bfs.parent[matched_var] == usize::MAX {
                        self.bfs.queue.push(matched_var);
                        self.bfs.parent[matched_var] = var_idx;
                    }
                } else {
                    // Free value: flip the matching along the BFS path.
                    self.graph
                        .augment_along_path(var_idx, val_idx, &self.bfs.parent);
                    return Ok(());
                }
            }
        }
        // No augmenting path: restore the original match before failing.
        if let Some(val_idx) = matched_val_idx {
            self.graph.val_to_var[val_idx] = Some(start_var);
            self.graph.var_to_val[start_var] = Some(val_idx);
        }
        Err(ctx.declare_conflict(
            move |ctx: &mut E::PropagationContext<'_>| -> Vec<<E as ReasoningEngine>::Atom> {
                self.build_hall_set_reason(ctx, &self.bfs.queue, |var, ctx, meaning| {
                    var.lit(ctx, meaning)
                })
            },
        ))
    }

    /// Posts the constraint that all of `vars` take distinct values.
    ///
    /// Builds the matching over `vars`, a trailed partition over their
    /// indices, and scratch space sized for one graph node per variable and
    /// per value. It then registers the propagator with `solver`.
    pub fn post<E>(solver: &mut E, vars: Vec<I>)
    where
        E: PostingActions + ?Sized,
        I: IntSolverActions<Engine> + IntInspectionActions<E>,
    {
        let graph = VariableValueMatching::new(solver, vars);
        let n = graph.n_vars();
        let n_nodes = n + graph.n_values();
        let partition = TrailedPartition::new(solver, n);
        solver.add_propagator(Box::new(Self {
            graph,
            dirty_vars: FxHashSet::default(),
            partition,
            bfs: AugmentingPathScratch::new(n),
            tarjan: TarjanScratch::new(n_nodes),
        }));
    }

    /// Pops the SCC rooted at `start_idx` off the Tarjan stack and prunes its
    /// values from variables outside it.
    ///
    /// 1. The SCC's variable indices and value offsets are collected into
    ///    `vars_buf` / `vals_buf`; the dummy node is skipped.
    /// 2. An SCC with no variable is ignored.
    /// 3. Otherwise its variables are split off in the partition.
    /// 4. If that created a new block, each SCC value is removed from the
    ///    variables at positions `orig_root..new_root`. These are presumably
    ///    the variables left in the original block; confirm against
    ///    `TrailedPartition::split_off`.
    ///
    /// Removals use a deferred reason tagged with the new block's root,
    /// presumably handed back to `explain` as its `data`.
    fn process_scc_root<E>(
        &mut self,
        start_idx: usize,
        ctx: &mut E::PropagationContext<'_>,
    ) -> Result<(), E::Conflict>
    where
        E: ReasoningEngine,
        I: IntSolverActions<E>,
    {
        let n_vars = self.graph.n_vars();
        let dummy = self.tarjan.dummy_node();
        self.tarjan.vars_buf.clear();
        self.tarjan.vals_buf.clear();
        let mut has_var_in_scc = false;
        // Pop nodes until the SCC root itself has been popped.
        loop {
            let node = self.tarjan.dfs_stack.pop().expect("non-empty DFS stack");
            self.tarjan.dfs_on_stack[node] = false;
            if node < n_vars {
                self.tarjan.vars_buf.push(node);
                has_var_in_scc = true;
            } else if node != dummy {
                // Value nodes are numbered after the variables.
                self.tarjan.vals_buf.push(node - n_vars);
            }
            if node == start_idx {
                break;
            }
        }
        if !has_var_in_scc {
            return Ok(());
        }
        let (orig_root, new_scc_root) = self.partition.split_off(&self.tarjan.vars_buf, ctx);
        // `None`: no new block was split off, so nothing is pruned.
        let Some(new_root) = new_scc_root else {
            return Ok(());
        };
        let scc_id = new_root;
        let val_reason = ctx.deferred_reason(scc_id as u64);
        for &val_idx in self.tarjan.vals_buf.iter() {
            let val = self.graph.value_at(val_idx);
            for pos in orig_root..new_root {
                // Partition positions map to variable indices via `elements()`.
                let var_idx = self.partition.elements()[pos];
                let var = self.graph.vars[var_idx].clone();
                if !var.in_domain(ctx, val) {
                    continue;
                }
                var.remove_val(ctx, val, val_reason)?;
            }
        }
        Ok(())
    }

    /// Runs an iterative Tarjan SCC search starting from node `start_idx`.
    ///
    /// Edges are generated when a node is first visited:
    /// - a variable points to every value in its domain except its matched one;
    /// - a matched value points to its variable, and a free value points to the
    ///   dummy node;
    /// - the dummy node points to every matched value.
    ///
    /// `next_dfs_index` and `n_left_visited` (the number of variable nodes
    /// visited) are shared across calls. When an SCC root completes, a split
    /// is recorded in `scc_split_detected` in either of two cases:
    /// - its low-link is above 1, i.e. the SCC does not contain the first node
    ///   of the pass, given the caller starts indices at 1;
    /// - some variable is still unvisited.
    ///
    /// Once a split has been recorded, every completed SCC is passed to
    /// `process_scc_root`.
    ///
    /// # Errors
    ///
    /// Propagates conflicts raised while pruning in `process_scc_root`.
    fn tarjan_dfs<E>(
        &mut self,
        start_idx: usize,
        next_dfs_index: &mut usize,
        n_left_visited: &mut usize,
        scc_split_detected: &mut bool,
        ctx: &mut E::PropagationContext<'_>,
    ) -> Result<(), E::Conflict>
    where
        E: ReasoningEngine,
        I: IntSolverActions<E>,
    {
        // Visits `node`. It counts the node if it is a variable, pushes it on
        // the Tarjan stack, sets its DFS index and initial low-link, and
        // appends its successors to `neighbours` as a new work frame.
        let push_frame = |this: &mut Self,
                          node: usize,
                          next_dfs_index: &mut usize,
                          n_left_visited: &mut usize,
                          ctx: &mut E::PropagationContext<'_>| {
            let n_vars = this.graph.n_vars();
            if node < n_vars {
                *n_left_visited += 1;
            }
            this.tarjan.dfs_stack.push(node);
            this.tarjan.dfs_on_stack[node] = true;
            this.tarjan.dfs_index[node] = *next_dfs_index;
            this.tarjan.low_link[node] = *next_dfs_index;
            *next_dfs_index += 1;
            let frame_start = this.tarjan.neighbours.len();
            let dummy = this.tarjan.dummy_node();
            if node < n_vars {
                // Variable: edges to all domain values except its matched one.
                for val in this.graph.vars[node].domain(ctx).iter().flatten() {
                    let val_idx = this.graph.value_index(val);
                    if this.graph.var_to_val[node] == Some(val_idx) {
                        continue;
                    }
                    this.tarjan.neighbours.push(n_vars + val_idx);
                }
            } else if node == dummy {
                // Dummy: edges to every matched value.
                for vi in 0..this.graph.n_values() {
                    if this.graph.val_to_var[vi].is_some() {
                        this.tarjan.neighbours.push(n_vars + vi);
                    }
                }
            } else {
                // Value: edge to its matched variable, or to the dummy if free.
                let val_idx = node - n_vars;
                if let Some(var_idx) = this.graph.val_to_var[val_idx] {
                    this.tarjan.neighbours.push(var_idx);
                } else {
                    this.tarjan.neighbours.push(dummy);
                }
            }
            let frame_end = this.tarjan.neighbours.len();
            this.tarjan.work_stack.push(TarjanFrame {
                node,
                frame_start,
                frame_end,
                i: frame_start,
            });
        };
        let n_vars = self.graph.n_vars();
        push_frame(self, start_idx, next_dfs_index, n_left_visited, ctx);
        while let Some(&TarjanFrame {
            node, frame_end, i, ..
        }) = self.tarjan.work_stack.last()
        {
            if i < frame_end {
                // Examine the next successor of the frame on top.
                let nb = self.tarjan.neighbours[i];
                self.tarjan.work_stack.last_mut().unwrap().i += 1;
                if self.tarjan.dfs_index[nb] != 0 {
                    // Already visited: only nodes still on the Tarjan stack
                    // belong to an open SCC and may lower the low-link.
                    if self.tarjan.dfs_on_stack[nb] {
                        self.tarjan.low_link[node] =
                            cmp::min(self.tarjan.low_link[node], self.tarjan.dfs_index[nb]);
                    }
                } else {
                    push_frame(self, nb, next_dfs_index, n_left_visited, ctx);
                }
                continue;
            }
            // All successors explored: close the frame and release its slice
            // of `neighbours`, which is the last one on the buffer.
            let frame = self.tarjan.work_stack.pop().unwrap();
            self.tarjan.neighbours.truncate(frame.frame_start);
            if self.tarjan.low_link[frame.node] == self.tarjan.dfs_index[frame.node] {
                // `frame.node` is the root of an SCC.
                if self.tarjan.low_link[frame.node] > 1 || *n_left_visited < n_vars {
                    *scc_split_detected = true;
                }
                if *scc_split_detected {
                    self.process_scc_root::<E>(frame.node, ctx)?;
                }
            }
            // Fold the finished node's low-link into its DFS parent.
            if let Some(&TarjanFrame { node: parent, .. }) = self.tarjan.work_stack.last() {
                self.tarjan.low_link[parent] = cmp::min(
                    self.tarjan.low_link[parent],
                    self.tarjan.low_link[frame.node],
                );
            }
        }
        Ok(())
    }
}
impl<I> IntUniqueDomain<I> {
fn build_hall_set_reason<C, A, F>(
&self,
ctx: &mut C,
members: &[usize],
mut get_lit: F,
) -> Vec<A>
where
C: ReasoningContext,
I: IntInspectionActions<C>,
F: FnMut(&I, &mut C, IntLitMeaning) -> A,
{
let mut dom_lb = IntVal::MAX;
let mut dom_ub = IntVal::MIN;
for &vid in members {
let (lb, ub) = self.graph.vars[vid].bounds(ctx);
dom_lb = cmp::min(dom_lb, lb);
dom_ub = cmp::max(dom_ub, ub);
}
let window = (dom_ub - dom_lb + 1) as usize;
let mut union_bits = FxHashSet::default();
for &vid in members {
for val in self.graph.vars[vid].domain(ctx).iter().flatten() {
union_bits.insert((val - dom_lb) as usize);
}
}
let n_holes = window - union_bits.len();
let mut reason: Vec<A> = Vec::with_capacity(members.len() * (2 + n_holes));
for &vid in members {
let var = &self.graph.vars[vid];
reason.push(get_lit(var, ctx, IntLitMeaning::GreaterEq(dom_lb)));
reason.push(get_lit(var, ctx, IntLitMeaning::Less(dom_ub + 1)));
for i in dom_lb..=dom_ub {
if !union_bits.contains(&((i - dom_lb) as usize)) {
reason.push(get_lit(var, ctx, IntLitMeaning::NotEq(i)));
}
}
}
reason
}
fn repair_matching_and_propagate_fixed<E>(
&mut self,
dirty: &FxHashSet<usize>,
ctx: &mut E::PropagationContext<'_>,
) -> Result<FxHashSet<usize>, E::Conflict>
where
E: ReasoningEngine,
I: IntSolverActions<E>,
{
let mut changed_scc = FxHashSet::default();
for &i in dirty.iter() {
let scc_id = self.partition.block_root(i, ctx);
let needs_augment = match self.graph.var_to_val[i] {
None => true,
Some(val_idx) => !self.graph.vars[i].in_domain(ctx, self.graph.value_at(val_idx)),
};
if needs_augment {
self.find_augmenting_path::<E>(i, ctx)?;
}
if let Some(val) = self.graph.vars[i].val(ctx) {
changed_scc.insert(scc_id);
let (orig_scc, new_scc) = self.partition.split_off(&[i], ctx);
if new_scc.is_some() {
let orig_scc_end = self.partition.block_end(orig_scc, ctx);
let reason_lit = self.graph.vars[i].lit(ctx, IntLitMeaning::Eq(val));
for pos in orig_scc..orig_scc_end {
let idx = self.partition.elements()[pos];
let v = self.graph.vars[idx].clone();
v.remove_val(ctx, val, [reason_lit.clone()].as_slice())?;
}
if orig_scc_end - orig_scc > 1 {
changed_scc.insert(orig_scc);
}
}
} else {
let scc_end = self.partition.block_end(scc_id, ctx);
if scc_end - scc_id > 1 {
changed_scc.insert(scc_id);
}
}
}
Ok(changed_scc)
}
fn run_tarjan_on_changed_sccs<E>(
&mut self,
changed_scc: &FxHashSet<usize>,
ctx: &mut E::PropagationContext<'_>,
) -> Result<(), E::Conflict>
where
E: ReasoningEngine,
I: IntSolverActions<E>,
{
self.tarjan.reset();
let mut next_dfs_index: usize = 1;
let mut n_left_visited: usize = 0;
let mut scc_split_detected = false;
for &i in changed_scc.iter() {
let scc_end = self.partition.block_end(i, ctx);
for var_idx in i..scc_end {
if self.tarjan.dfs_index[var_idx] == 0 {
self.tarjan_dfs::<E>(
var_idx,
&mut next_dfs_index,
&mut n_left_visited,
&mut scc_split_detected,
ctx,
)?;
}
}
}
Ok(())
}
}
impl<E, I> Propagator<E> for IntUniqueDomain<I>
where
    E: ReasoningEngine,
    I: IntSolverActions<E>,
{
    /// Forgets all pending domain changes when the solver backtracks.
    ///
    /// NOTE(review): the matching in `self.graph` is not trailed. A variable
    /// may change at a level that survives this backtrack without having been
    /// handled by `propagate` yet. Clearing its dirty mark here could then
    /// leave a stale match. Confirm this cannot happen.
    fn advise_of_backtrack(&mut self, _: &mut E::NotificationContext<'_>) {
        self.dirty_vars.clear();
    }

    /// Marks variable `data` as dirty. It requests scheduling only when the
    /// variable's domain now holds fewer values than there are variables.
    ///
    /// A variable with at least that many values can presumably always be
    /// matched. It stays dirty and is repaired on the next run.
    fn advise_of_int_change(
        &mut self,
        ctx: &mut E::NotificationContext<'_>,
        data: u64,
        _event: IntEvent,
    ) -> bool {
        let var_idx = data as usize;
        let values_left = self.graph.vars[var_idx].domain(ctx).card().unwrap();
        self.dirty_vars.insert(var_idx);
        values_left < self.graph.n_vars()
    }

    /// Explains a value removal by the Hall set formed by the partition block
    /// rooted at `data`.
    fn explain(
        &mut self,
        ctx: &mut E::ExplanationContext<'_>,
        _lit: E::Atom,
        data: u64,
    ) -> Conjunction<E::Atom> {
        let block_root = data as usize;
        let block_end = self.partition.block_end(block_root, ctx);
        let hall_set = &self.partition.elements()[block_root..block_end];
        self.build_hall_set_reason(ctx, hall_set, |var, ctx, meaning| {
            var.lit_relaxed(ctx, meaning).0
        })
    }

    /// Subscribes to domain changes of every variable and schedules an
    /// initial run with every variable marked dirty.
    fn initialize(&mut self, ctx: &mut E::InitializationContext<'_>) {
        self.dirty_vars.extend(0..self.graph.n_vars());
        ctx.set_priority(PriorityLevel::Low);
        for (var_idx, var) in self.graph.vars.iter().enumerate() {
            var.advise_when(ctx, IntPropCond::Domain, var_idx as u64);
        }
        ctx.advise_on_backtrack();
        ctx.enqueue_now(true);
    }

    #[tracing::instrument(
        name = "int_unique_domain",
        target = "solver",
        level = "trace",
        skip(self, ctx)
    )]
    fn propagate(&mut self, ctx: &mut E::PropagationContext<'_>) -> Result<(), E::Conflict> {
        // Move the pending set out so it can be read while `self` is mutated.
        // The emptied allocation is handed back for reuse even on conflict.
        let mut pending = std::mem::take(&mut self.dirty_vars);
        let repaired = self.repair_matching_and_propagate_fixed(&pending, ctx);
        pending.clear();
        self.dirty_vars = pending;
        let changed_scc = repaired?;
        self.run_tarjan_on_changed_sccs(&changed_scc, ctx)
    }
}
impl TarjanScratch {
    /// Index of the dummy node, which occupies the final slot of every
    /// per-node vector.
    fn dummy_node(&self) -> usize {
        let slots = self.dfs_on_stack.len();
        slots - 1
    }

    /// Allocates scratch space for `n_nodes` graph nodes plus the dummy node.
    fn new(n_nodes: usize) -> Self {
        let slots = n_nodes + 1;
        Self {
            dfs_index: vec![0; slots],
            low_link: vec![0; slots],
            dfs_on_stack: vec![false; slots],
            dfs_stack: Vec::new(),
            work_stack: Vec::new(),
            neighbours: Vec::new(),
            vars_buf: Vec::new(),
            vals_buf: Vec::new(),
        }
    }

    /// Marks every node unvisited and empties the DFS stacks. The SCC
    /// buffers are left untouched.
    fn reset(&mut self) {
        self.work_stack.clear();
        self.neighbours.clear();
        self.dfs_stack.clear();
        self.dfs_on_stack.iter_mut().for_each(|on_stack| *on_stack = false);
        self.dfs_index.iter_mut().for_each(|index| *index = 0);
        self.low_link.iter_mut().for_each(|link| *link = 0);
    }
}
impl<I> VariableValueMatching<I> {
    /// Flips the matching along an augmenting path recorded by a BFS.
    ///
    /// `end_var` is the last variable on the path and `end_val` the free value
    /// it reached; `end_var` takes `end_val`. Walking backwards, each
    /// variable's previous value goes to its BFS parent (`bfs_parent[var]`),
    /// which reached it through that value. The walk stops at the first
    /// variable that had no previous match, i.e. the BFS start.
    fn augment_along_path(&mut self, end_var: usize, end_val: usize, bfs_parent: &[usize]) {
        let mut cur_var = end_var;
        let mut cur_val = end_val;
        loop {
            let prev_val = self.var_to_val[cur_var];
            self.val_to_var[cur_val] = Some(cur_var);
            self.var_to_val[cur_var] = Some(cur_val);
            let Some(pv) = prev_val else {
                break;
            };
            cur_val = pv;
            let parent = bfs_parent[cur_var];
            debug_assert_ne!(parent, usize::MAX, "BFS parent missing");
            cur_var = parent;
        }
    }

    /// Number of value slots, i.e. the width of the value range.
    fn n_values(&self) -> usize {
        self.val_to_var.len()
    }

    /// Number of constrained variables.
    fn n_vars(&self) -> usize {
        self.vars.len()
    }

    /// Creates an empty matching for `vars`.
    ///
    /// The value table spans the hull of all variables' current bounds, and
    /// value offsets are taken relative to the smallest lower bound. An empty
    /// `vars` yields an empty value table. Otherwise the `IntVal::MAX` /
    /// `IntVal::MIN` seeds would produce an overflowing width.
    fn new<C: ReasoningContext + ?Sized>(ctx: &mut C, vars: Vec<I>) -> Self
    where
        I: IntInspectionActions<C>,
    {
        if vars.is_empty() {
            return Self {
                vars,
                union_domain_lb: 0,
                var_to_val: Vec::new(),
                val_to_var: Vec::new(),
            };
        }
        let n = vars.len();
        let mut lb = IntVal::MAX;
        let mut ub = IntVal::MIN;
        for v in &vars {
            let (l, u) = v.bounds(ctx);
            lb = cmp::min(lb, l);
            ub = cmp::max(ub, u);
        }
        debug_assert!(lb <= ub);
        Self {
            vars,
            union_domain_lb: lb,
            var_to_val: vec![None; n],
            val_to_var: vec![None; (ub - lb + 1) as usize],
        }
    }

    /// Value stored at offset `right_idx`.
    fn value_at(&self, right_idx: usize) -> IntVal {
        self.union_domain_lb + right_idx as IntVal
    }

    /// Offset of `val` in the value table. `val` must not be below
    /// `union_domain_lb`, otherwise the cast wraps.
    fn value_index(&self, val: IntVal) -> usize {
        (val - self.union_domain_lb) as usize
    }
}
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use tracing_test::traced_test;

    use crate::{
        IntSet, IntVal,
        constraints::int_unique::IntUniqueDomain,
        solver::{LiteralStrategy, Solver},
    };

    /// Creates an integer decision variable over `$dom` in solver `$slv`,
    /// with order and direct literals created eagerly, and yields its view.
    macro_rules! eager_int_var {
        ($slv:ident, $dom:expr) => {
            $slv.new_int_decision($dom)
                .order_literals(LiteralStrategy::Eager)
                .direct_literals(LiteralStrategy::Eager)
                .view()
        };
    }

    #[test]
    #[traced_test]
    fn test_all_different_domain_deep_chain() {
        const N: IntVal = 300;
        let mut slv = Solver::default();
        // x_i ranges over {i, i + 1}, except the last variable, which is fixed
        // to N. The only solution is therefore x_i = i for every i.
        let vars: Vec<_> = (1..=N)
            .map(|i| {
                let dom = if i < N { i..=i + 1 } else { N..=N };
                eager_int_var!(slv, dom)
            })
            .collect();
        IntUniqueDomain::post(&mut slv, vars.clone());
        slv.assert_all_solutions(&vars, |sol| {
            sol.iter()
                .zip(1..)
                .all(|(v, expected)| *v == crate::solver::Value::Int(expected))
        });
    }

    #[test]
    #[traced_test]
    fn test_all_different_domain_filtering() {
        let mut slv = Solver::default();
        let a = eager_int_var!(slv, 1..=2);
        let b = eager_int_var!(slv, 1..=2);
        let c = eager_int_var!(slv, 1..=3);
        // `a` and `b` use up {1, 2}, which forces `c` to 3.
        IntUniqueDomain::post(&mut slv, vec![a, b, c]);
        slv.assert_all_solutions(&[a, b, c], |sol| {
            sol.iter().all_unique() && sol[2] == crate::solver::Value::Int(3)
        });
    }

    #[test]
    #[traced_test]
    fn test_all_different_domain_interior_hole() {
        use crate::actions::{IntDecisionActions, IntInspectionActions};
        let mut slv = Solver::default();
        let a = eager_int_var!(slv, 3..=4);
        let b = eager_int_var!(slv, 3..=4);
        let e = eager_int_var!(slv, 1..=6);
        IntUniqueDomain::post(&mut slv, vec![a, b, e]);
        let propagated = slv.propagate_next().unwrap();

        // {a, b} is a Hall set on {3, 4}. Both values leave `e`'s domain,
        // which opens an interior hole while its bounds stay (1, 6).
        assert_eq!(a.domain(&slv), IntSet::from(3..=4));
        assert_eq!(b.domain(&slv), IntSet::from(3..=4));
        assert_eq!(e.domain(&slv), IntSet::from_iter([1..=2, 5..=6]));
        assert!(!e.in_domain(&slv, 3));
        assert!(!e.in_domain(&slv, 4));
        assert_eq!(e.bounds(&slv), (1, 6));

        // Exactly the two disequalities on `e` are propagated.
        let expected = [
            e.lit(&mut slv, crate::solver::IntLitMeaning::NotEq(3)),
            e.lit(&mut slv, crate::solver::IntLitMeaning::NotEq(4)),
        ];
        assert_eq!(propagated.len(), expected.len());
        for lit in expected {
            assert!(
                propagated.contains(&lit),
                "missing propagated literal {lit:?}"
            );
        }
    }

    #[test]
    #[traced_test]
    fn test_all_different_domain_sat() {
        let mut slv = Solver::default();
        let a = eager_int_var!(slv, 1..=3);
        let b = eager_int_var!(slv, 1..=3);
        let c = eager_int_var!(slv, 1..=3);
        IntUniqueDomain::post(&mut slv, vec![a, b, c]);
        slv.assert_all_solutions(&[a, b, c], |sol| sol.iter().all_unique());
    }

    #[test]
    #[traced_test]
    fn test_all_different_domain_unsat() {
        let mut slv = Solver::default();
        // Three variables over two values: no assignment can be all-different.
        let a = eager_int_var!(slv, 1..=2);
        let b = eager_int_var!(slv, 1..=2);
        let c = eager_int_var!(slv, 1..=2);
        IntUniqueDomain::post(&mut slv, vec![a, b, c]);
        slv.assert_unsatisfiable();
    }
}