use std::iter::once;
use itertools::Itertools;
use rangelist::IntervalIterator;
use rustc_hash::FxHashMap;
use crate::{
IntSet, IntVal,
actions::{
ConstructionActions, InitActions, IntDecisionActions, IntInspectionActions, IntPropCond,
IntSimplificationActions, PostingActions, ReasoningContext, ReasoningEngine,
SimplificationActions, Trailed, TrailingActions,
},
constraints::{
Constraint, IntModelActions, IntSolverActions, Propagator, SimplificationStatus,
},
lower::{LoweringContext, LoweringError},
model::View,
solver::{IntLitMeaning, engine::Engine, queue::PriorityLevel},
};
/// Propagator for the element constraint `result = vars[index]` that reasons
/// on the bounds of the array elements and of the result, and prunes positions
/// from the domain of the index.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct IntArrayElementBounds<I1, I2, I3> {
/// Array of integer expressions from which `index` selects one element.
vars: Vec<I1>,
/// Zero-based position in `vars` of the selected element.
pub(crate) index: I2,
/// Expression that is constrained to be equal to the selected element.
result: I3,
/// Trailed position in `vars` of the element that currently witnesses the
/// smallest lower bound among the candidate elements.
min_support: Trailed<usize>,
/// Trailed position in `vars` of the element that currently witnesses the
/// largest upper bound among the candidate elements.
max_support: Trailed<usize>,
}
/// Element constraint over an array of constant integer values.
///
/// This is a thin wrapper around an [`IntArrayElementBounds`] whose array
/// elements are plain [`IntVal`]s.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct IntValArrayElement<I1, I2>(pub(crate) IntArrayElementBounds<IntVal, I1, I2>);
impl<I1, I2, I3> IntArrayElementBounds<I1, I2, I3> {
    /// Create a new [`IntArrayElementBounds`] object for the constraint
    /// `result = collection[index]`.
    ///
    /// The initial supports are computed by scanning every value in the
    /// current domain of `index` that is a valid position in `collection`:
    /// `min_support` is the position with the smallest lower bound and
    /// `max_support` the position with the largest upper bound.
    ///
    /// # Panics
    ///
    /// Panics if the current domain of `index` contains no valid position in
    /// `collection`.
    pub(crate) fn new<E>(engine: &mut E, collection: Vec<I1>, index: I2, result: I3) -> Self
    where
        E: ConstructionActions + ReasoningContext + ?Sized,
        I1: IntInspectionActions<E>,
        I2: IntInspectionActions<E>,
    {
        let mut min_support = None;
        let mut max_support = None;
        let mut min_lb = IntVal::MAX;
        let mut max_ub = IntVal::MIN;
        for i in index
            .domain(engine)
            .iter()
            .flatten()
            .filter(|&v| v >= 0 && v < collection.len() as IntVal)
        {
            let i = i as usize;
            let (lb, ub) = collection[i].bounds(engine);
            // The `is_none` test makes sure the first candidate is always
            // accepted, even when its bound equals the initial sentinel.
            if min_support.is_none() || lb < min_lb {
                min_support = Some(i);
                min_lb = lb;
            }
            if max_support.is_none() || ub > max_ub {
                max_support = Some(i);
                max_ub = ub;
            }
        }
        let min_support = engine.new_trailed(
            min_support.expect("index domain must contain a valid position in the array"),
        );
        let max_support = engine.new_trailed(
            max_support.expect("index domain must contain a valid position in the array"),
        );
        Self {
            // `collection` is owned and no longer borrowed here, so it is
            // moved into the propagator instead of being cloned.
            vars: collection,
            result,
            index,
            min_support,
            max_support,
        }
    }

    /// Post the constraint `result = collection[index]` in `solver`.
    ///
    /// Two unit clauses first restrict `index` to `0..collection.len()`, and
    /// afterwards a new [`IntArrayElementBounds`] propagator is added.
    ///
    /// # Errors
    ///
    /// Returns the conflict reported by `add_clause` when restricting the
    /// index fails.
    pub fn post<E>(
        solver: &mut E,
        collection: Vec<I1>,
        index: I2,
        result: I3,
    ) -> Result<(), E::Conflict>
    where
        E: PostingActions + ?Sized,
        I1: IntSolverActions<Engine> + IntDecisionActions<E>,
        I2: IntSolverActions<Engine> + IntDecisionActions<E>,
        I3: IntSolverActions<Engine>,
    {
        let index_ub = index.lit(solver, IntLitMeaning::Less(collection.len() as IntVal));
        let index_lb = index.lit(solver, IntLitMeaning::GreaterEq(0));
        solver.add_clause([index_ub])?;
        solver.add_clause([index_lb])?;
        let me = Self::new(solver, collection, index, result);
        solver.add_propagator(Box::new(me));
        Ok(())
    }
}
impl<E, I1, I2, I3> Constraint<E> for IntArrayElementBounds<I1, I2, I3>
where
    E: ReasoningEngine,
    for<'a> E::PropagationContext<'a>: SimplificationActions<Target = E>,
    I1: IntModelActions<E>,
    I2: IntModelActions<E>,
    I3: IntModelActions<E>,
    View<IntVal>: IntModelActions<E>,
    IntVal: IntModelActions<E>,
{
    /// Simplify the constraint at the model level.
    ///
    /// The index is first restricted to the valid positions of the array and
    /// one round of propagation is run. The constraint is then reported as
    /// subsumed in two situations: when the index is fixed (the selected
    /// element is unified with the result), or when every array element is
    /// fixed (an [`IntValArrayElement`] over the constant values is posted
    /// instead). Otherwise `NoFixpoint` is returned.
    fn simplify(
        &mut self,
        ctx: &mut E::PropagationContext<'_>,
    ) -> Result<SimplificationStatus, E::Conflict> {
        let last_position = self.vars.len() as IntVal - 1;
        self.index.tighten_min(ctx, 0, [])?;
        self.index.tighten_max(ctx, last_position, [])?;
        self.propagate(ctx)?;

        // Case 1: the index is fixed, so the constraint is a plain equality.
        if let Some(i) = self.index.val(ctx) {
            self.vars[i as usize]
                .clone()
                .into()
                .unify(ctx, self.result.clone())?;
            return Ok(SimplificationStatus::Subsumed);
        }

        // Case 2: all elements are fixed, so switch to the constant-array
        // form of the constraint.
        let fixed_values: Option<Vec<_>> = self.vars.iter().map(|v| v.val(ctx)).collect();
        if let Some(vars) = fixed_values {
            let rewrite = IntValArrayElement(IntArrayElementBounds {
                vars,
                index: self.index.clone(),
                result: self.result.clone(),
                min_support: self.min_support,
                max_support: self.max_support,
            });
            ctx.post_constraint(rewrite);
            return Ok(SimplificationStatus::Subsumed);
        }

        Ok(SimplificationStatus::NoFixpoint)
    }

    /// Lower the constraint by translating its expressions into solver views
    /// and posting an [`IntArrayElementBounds`] propagator on them.
    fn to_solver(&self, ctx: &mut LoweringContext<'_>) -> Result<(), LoweringError> {
        let mut elements = Vec::with_capacity(self.vars.len());
        for v in &self.vars {
            elements.push(ctx.solver_view(v.clone().into()));
        }
        let result = ctx.solver_view(self.result.clone().into());
        let index = ctx.solver_view(self.index.clone().into());
        // NOTE(review): a conflict reported by `post` panics here instead of
        // being returned to the caller -- confirm that this is intended.
        IntArrayElementBounds::post(ctx, elements, index, result).unwrap();
        Ok(())
    }
}
impl<E, I1, I2, I3> Propagator<E> for IntArrayElementBounds<I1, I2, I3>
where
    E: ReasoningEngine,
    I1: IntSolverActions<E>,
    I2: IntSolverActions<E>,
    I3: IntSolverActions<E>,
{
    /// Register the propagator at low priority and subscribe it to bound
    /// changes of the result, domain changes of the index, and bound changes
    /// of every array element that the index can currently select.
    fn initialize(&mut self, ctx: &mut E::InitializationContext<'_>) {
        ctx.set_priority(PriorityLevel::Low);
        self.result.enqueue_when(ctx, IntPropCond::Bounds);
        self.index.enqueue_when(ctx, IntPropCond::Domain);
        for i in self
            .index
            .domain(ctx)
            .iter()
            .flatten()
            .filter(|&v| v >= 0 && v < self.vars.len() as IntVal)
        {
            self.vars[i as usize].enqueue_when(ctx, IntPropCond::Bounds);
        }
    }

    /// Propagate `result = vars[index]`.
    ///
    /// When the index is fixed, the bounds of the selected element and of the
    /// result are made equal. Otherwise, positions whose element bounds do
    /// not overlap the bounds of the result are removed from the index, and
    /// the result is tightened to the smallest lower bound and the largest
    /// upper bound found among the candidate elements. The trailed supports
    /// remember which elements witnessed those bounds.
    ///
    /// # Errors
    ///
    /// Returns the conflict produced by any of the tightening or removal
    /// operations.
    #[tracing::instrument(
        name = "int_array_element_bounds",
        target = "solver",
        level = "trace",
        skip(self, ctx)
    )]
    fn propagate(&mut self, ctx: &mut E::PropagationContext<'_>) -> Result<(), E::Conflict> {
        // Fixed index: the constraint is an equality between the selected
        // element and the result, so exchange bounds in both directions.
        if let Some(fixed_index) = self.index.val(ctx) {
            let index_val_lit = self.index.val_lit(ctx).unwrap();
            let fixed_var = &self.vars[fixed_index as usize];
            self.result.tighten_min(
                ctx,
                fixed_var.min(ctx),
                |ctx: &mut E::PropagationContext<'_>| {
                    [index_val_lit.clone(), fixed_var.min_lit(ctx)]
                },
            )?;
            fixed_var.tighten_min(
                ctx,
                self.result.min(ctx),
                |ctx: &mut E::PropagationContext<'_>| {
                    [index_val_lit.clone(), self.result.min_lit(ctx)]
                },
            )?;
            self.result.tighten_max(
                ctx,
                fixed_var.max(ctx),
                |ctx: &mut E::PropagationContext<'_>| {
                    [index_val_lit.clone(), fixed_var.max_lit(ctx)]
                },
            )?;
            fixed_var.tighten_max(
                ctx,
                self.result.max(ctx),
                |ctx: &mut E::PropagationContext<'_>| {
                    [index_val_lit.clone(), self.result.max_lit(ctx)]
                },
            )?;
            return Ok(());
        }

        let (result_lb, result_ub) = self.result.bounds(ctx);
        let idx_dom: IntSet = self.index.domain(ctx);
        let min_support = ctx.trailed(self.min_support);
        let max_support = ctx.trailed(self.max_support);
        let old_min = self.vars[min_support].min(ctx);
        let old_max = self.vars[max_support].max(ctx);
        // A support has to be recomputed when it might allow tightening the
        // result, or when its position is no longer in the index domain.
        let mut need_min_support =
            old_min > result_lb || !idx_dom.contains(&(min_support as IntVal));
        let mut need_max_support =
            old_max < result_ub || !idx_dom.contains(&(max_support as IntVal));
        let mut new_min_support = min_support;
        let mut new_max_support = max_support;
        let mut new_min = if need_min_support {
            IntVal::MAX
        } else {
            old_min
        };
        let mut new_max = if need_max_support {
            IntVal::MIN
        } else {
            old_max
        };

        for i in idx_dom.iter().flatten() {
            // Valid positions are `0..len`; `len` itself would make the
            // indexing below panic.
            debug_assert!(i >= 0 && i < self.vars.len() as IntVal);
            let i = i as usize;
            let v = &self.vars[i];
            let (v_lb, v_ub) = v.bounds(ctx);
            // The element lies entirely above the result: it cannot be selected.
            if result_ub < v_lb {
                self.index.remove_val(
                    ctx,
                    i as IntVal,
                    |ctx: &mut E::PropagationContext<'_>| {
                        [
                            self.result.lit(ctx, IntLitMeaning::Less(v_lb)),
                            v.min_lit(ctx),
                        ]
                    },
                )?;
            }
            // The element lies entirely below the result: it cannot be selected.
            if v_ub < result_lb {
                self.index.remove_val(
                    ctx,
                    i as IntVal,
                    |ctx: &mut E::PropagationContext<'_>| {
                        [
                            self.result.lit(ctx, IntLitMeaning::GreaterEq(v_ub + 1)),
                            v.max_lit(ctx),
                        ]
                    },
                )?;
            }
            // Track a new support; the search stops as soon as the candidate
            // can no longer tighten the result.
            if need_min_support && v_lb < new_min {
                new_min_support = i;
                new_min = v_lb;
                need_min_support = new_min > result_lb;
            }
            if need_max_support && v_ub > new_max {
                new_max_support = i;
                new_max = v_ub;
                need_max_support = new_max < result_ub;
            }
        }
        ctx.set_trailed(self.min_support, new_min_support);
        ctx.set_trailed(self.max_support, new_max_support);

        if new_min > result_lb {
            self.result
                .tighten_min(ctx, new_min, |ctx: &mut E::PropagationContext<'_>| {
                    // For every position: either the element is at least
                    // `new_min`, or the position is excluded from the index.
                    let mut reason = Vec::with_capacity(self.vars.len());
                    let dom = self.index.domain(ctx);
                    let mut dom = dom.iter().flatten().peekable();
                    for (i, v) in self.vars.iter().enumerate() {
                        debug_assert!(dom.peek().is_none() || *dom.peek().unwrap() >= i as IntVal);
                        if dom.peek() == Some(&(i as IntVal)) {
                            reason.push(v.lit(ctx, IntLitMeaning::GreaterEq(new_min)));
                            dom.next();
                        } else {
                            reason.push(self.index.lit(ctx, IntLitMeaning::NotEq(i as IntVal)));
                        }
                    }
                    reason
                })?;
        }
        if new_max < result_ub {
            self.result
                .tighten_max(ctx, new_max, |ctx: &mut E::PropagationContext<'_>| {
                    // For every position: either the element is at most
                    // `new_max`, or the position is excluded from the index.
                    let mut reason = Vec::with_capacity(self.vars.len());
                    let dom = self.index.domain(ctx);
                    let mut dom = dom.iter().flatten().peekable();
                    for (i, v) in self.vars.iter().enumerate() {
                        debug_assert!(dom.peek().is_none() || *dom.peek().unwrap() >= i as IntVal);
                        if dom.peek() == Some(&(i as IntVal)) {
                            reason.push(v.lit(ctx, IntLitMeaning::Less(new_max + 1)));
                            dom.next();
                        } else {
                            reason.push(self.index.lit(ctx, IntLitMeaning::NotEq(i as IntVal)));
                        }
                    }
                    reason
                })?;
        }
        Ok(())
    }
}
impl<E, I1, I2> Constraint<E> for IntValArrayElement<I1, I2>
where
    E: ReasoningEngine,
    for<'a> E::PropagationContext<'a>: SimplificationActions<Target = E>,
    I1: IntModelActions<E>,
    I2: IntModelActions<E>,
    IntVal: IntModelActions<E>,
    View<IntVal>: IntModelActions<E>,
{
    /// Simplify the constraint at the model level.
    ///
    /// The index is restricted to the valid positions of the array and one
    /// round of propagation of the wrapped propagator is run. When the index
    /// ends up fixed, the result is unified with the selected constant and the
    /// constraint is reported as subsumed; otherwise `NoFixpoint` is returned.
    fn simplify(
        &mut self,
        ctx: &mut E::PropagationContext<'_>,
    ) -> Result<SimplificationStatus, E::Conflict> {
        let inner = &mut self.0;
        let last_position = inner.vars.len() as IntVal - 1;
        inner.index.tighten_min(ctx, 0, [])?;
        inner.index.tighten_max(ctx, last_position, [])?;
        inner.propagate(ctx)?;

        let Some(i) = inner.index.val(ctx) else {
            return Ok(SimplificationStatus::NoFixpoint);
        };
        inner
            .result
            .clone()
            .into()
            .unify(ctx, inner.vars[i as usize])?;
        Ok(SimplificationStatus::Subsumed)
    }

    /// Lower the constraint to clauses.
    ///
    /// Array positions are grouped by the constant stored at them. For each
    /// distinct constant `val`, clauses are added stating that selecting any
    /// of its positions implies `result = val`, and that `result = val`
    /// implies that one of its positions is selected.
    fn to_solver(&self, slv: &mut LoweringContext<'_>) -> Result<(), LoweringError> {
        let index = slv.solver_view(self.0.index.clone().into());
        let result = slv.solver_view(self.0.result.clone().into());

        // Map every distinct constant to the positions at which it occurs.
        let mut positions_by_value = FxHashMap::default();
        for (pos, &val) in self.0.vars.iter().enumerate() {
            positions_by_value
                .entry(val)
                .or_insert_with(Vec::new)
                .push(pos as IntVal);
        }

        #[expect(clippy::iter_over_hash_type, reason = "FxHashMap::iter is stable")]
        for (val, positions) in positions_by_value {
            let result_is_val = result.lit(slv, IntLitMeaning::Eq(val));
            let index_lits: Vec<_> = positions
                .into_iter()
                .map(|pos| index.lit(slv, IntLitMeaning::Eq(pos)))
                .collect();
            // (index = pos) -> (result = val), for each position of `val`.
            for &lit in &index_lits {
                slv.add_clause([!lit, result_is_val])?;
            }
            // (result = val) -> one of the positions of `val` is selected.
            slv.add_clause(index_lits.into_iter().chain(once(!result_is_val)))?;
        }
        Ok(())
    }
}
impl<E, I1, I2> Propagator<E> for IntValArrayElement<I1, I2>
where
E: ReasoningEngine,
I1: IntSolverActions<E>,
I2: IntSolverActions<E>,
IntVal: IntSolverActions<E>,
{
/// Delegate initialization to the wrapped [`IntArrayElementBounds`].
fn initialize(&mut self, ctx: &mut E::InitializationContext<'_>) {
self.0.initialize(ctx);
}
/// Never expected to be called; panics through `unreachable!()` if it is.
///
/// NOTE(review): presumably this holds because `to_solver` for this type
/// only adds clauses and never registers the type as a propagator --
/// confirm against the callers.
fn propagate(&mut self, _: &mut E::PropagationContext<'_>) -> Result<(), E::Conflict> {
unreachable!()
}
}
#[cfg(test)]
mod tests {
use expect_test::expect;
use tracing_test::traced_test;
use crate::{
IntSet,
actions::{IntInspectionActions, IntPropagationActions},
constraints::int_array_element::IntArrayElementBounds,
model::Model,
solver::{LiteralStrategy, Solver},
};
/// Element over the fixed values `[1, 3, 4]`: after position 2 is removed
/// from the index, lowering the model is expected to leave the index in
/// `0..=1` and the result within the bounds `(1, 3)`.
#[test]
fn recompute_max_support_after_index_pruning() {
let mut prb = Model::default();
let a = prb.new_int_decision(1..=1);
let b = prb.new_int_decision(3..=3);
let c = prb.new_int_decision(4..=4);
let result = prb.new_int_decision(1..=4);
let index = prb.new_int_decision(0..=2);
prb.element(vec![a, b, c])
.index(index)
.result(result)
.post()
.unwrap();
// Remove the position of the element with the largest upper bound.
index.remove_val(&mut prb, 2, []).unwrap();
let _: (Solver, _) = prb.lower().to_solver().unwrap();
assert_eq!(index.domain(&prb), (0..=1).into());
assert_eq!(result.bounds(&prb), (1, 3));
}
/// Element over the fixed values `[1, 3, 4]`: after position 0 is removed
/// from the index, lowering the model is expected to leave the index in
/// `1..=2` and the result within the bounds `(3, 4)`.
#[test]
fn recompute_min_support_after_index_pruning() {
let mut prb = Model::default();
let a = prb.new_int_decision(1..=1);
let b = prb.new_int_decision(3..=3);
let c = prb.new_int_decision(4..=4);
let result = prb.new_int_decision(2..=4);
let index = prb.new_int_decision(0..=2);
prb.element(vec![a, b, c])
.index(index)
.result(result)
.post()
.unwrap();
// Remove the position of the element with the smallest lower bound.
index.remove_val(&mut prb, 0, []).unwrap();
let _: (Solver, _) = prb.lower().to_solver().unwrap();
assert_eq!(index.domain(&prb), (1..=2).into());
assert_eq!(result.bounds(&prb), (3, 4));
}
/// Enumerate all solutions of `y = [a, b, c][index]` posted directly on a
/// solver and compare them with the expected snapshot. The columns of the
/// snapshot are `index, y, a, b, c`.
#[test]
#[traced_test]
fn test_element_bounds_sat() {
let mut slv = Solver::default();
let a = slv
.new_int_decision(3..=4)
.order_literals(LiteralStrategy::Eager)
.view();
let b = slv
.new_int_decision(2..=3)
.order_literals(LiteralStrategy::Eager)
.view();
let c = slv
.new_int_decision(4..=5)
.order_literals(LiteralStrategy::Eager)
.view();
let y = slv
.new_int_decision(3..=4)
.order_literals(LiteralStrategy::Eager)
.view();
let index = slv
.new_int_decision(0..=2)
.order_literals(LiteralStrategy::Eager)
.view();
IntArrayElementBounds::post(&mut slv, vec![a, b, c], index, y).unwrap();
slv.expect_solutions(
&[index, y, a, b, c],
expect![[r#"
0, 3, 3, 2, 4
0, 3, 3, 2, 5
0, 3, 3, 3, 4
0, 3, 3, 3, 5
0, 4, 4, 2, 4
0, 4, 4, 2, 5
0, 4, 4, 3, 4
0, 4, 4, 3, 5
1, 3, 3, 3, 4
1, 3, 3, 3, 5
1, 3, 4, 3, 4
1, 3, 4, 3, 5
2, 4, 3, 2, 4
2, 4, 3, 3, 4
2, 4, 4, 2, 4
2, 4, 4, 3, 4"#]],
);
}
/// The index domain `{0, 3}` contains a value (3) that is not a valid
/// position in the two-element array, so every expected solution selects
/// position 0. The columns of the snapshot are `index, y, a, b`.
#[test]
#[traced_test]
fn test_element_holes() {
let mut slv = Solver::default();
let a = slv
.new_int_decision(1..=3)
.order_literals(LiteralStrategy::Eager)
.view();
let b = slv
.new_int_decision(1..=3)
.order_literals(LiteralStrategy::Eager)
.view();
let y = slv
.new_int_decision(3..=4)
.order_literals(LiteralStrategy::Eager)
.view();
let index = slv
.new_int_decision(IntSet::from_iter([0..=0, 3..=3]))
.order_literals(LiteralStrategy::Eager)
.view();
IntArrayElementBounds::post(&mut slv, vec![a, b], index, y).unwrap();
slv.expect_solutions(
&[index, y, a, b],
expect![[r#"
0, 3, 3, 1
0, 3, 3, 2
0, 3, 3, 3"#]],
);
}
/// Every array element has a lower bound of at least 3 while the result is
/// restricted to `1..=2`, so posting the constraint is expected to fail.
#[test]
#[traced_test]
fn test_element_unsat() {
let mut prb = Model::default();
let a = prb.new_int_decision(3..=5);
let b = prb.new_int_decision(4..=5);
let c = prb.new_int_decision(4..=10);
let result = prb.new_int_decision(1..=2);
let index = prb.new_int_decision(0..=2);
assert!(
prb.element(vec![a, b, c])
.index(index)
.result(result)
.post()
.is_err()
);
}
}