use std::iter::once;
use crate::{
IntVal,
actions::{
BoolInitActions, BoolSimplificationActions, IntDecisionActions, IntInitActions,
IntInspectionActions, IntPropCond, IntPropagationActions, ReasoningEngine,
},
constraints::{
BoolModelActions, Constraint, IntModelActions, Propagator, SimplificationStatus,
},
lower::{LoweringContext, LoweringError},
model::view::View,
solver::IntLitMeaning,
};
/// Model-level representation of a boolean element constraint: `result` must
/// equal the element of `array` found at position `index` (positions start at
/// `0`).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BoolDecisionArrayElement {
/// The boolean views from which one element is selected.
pub(crate) array: Vec<View<bool>>,
/// Zero-based position into `array` of the selected element.
pub(crate) index: View<IntVal>,
/// The boolean view that must equal the selected element.
pub(crate) result: View<bool>,
}
impl<E> Constraint<E> for BoolDecisionArrayElement
where
E: ReasoningEngine,
View<IntVal>: IntModelActions<E>,
View<bool>: BoolModelActions<E>,
{
fn simplify(
&mut self,
ctx: &mut E::PropagationContext<'_>,
) -> Result<SimplificationStatus, E::Conflict> {
Self::propagate(self, ctx)?;
if let Some(i) = self.index.val(ctx) {
self.array[i as usize].unify(ctx, self.result)?;
return Ok(SimplificationStatus::Subsumed);
}
Ok(SimplificationStatus::NoFixpoint)
}
fn to_solver(&self, slv: &mut LoweringContext<'_>) -> Result<(), LoweringError> {
let result = slv.solver_view(self.result);
let index = slv.solver_view(self.index);
let arr: Vec<_> = self.array.iter().map(|&v| slv.solver_view(v)).collect();
for (i, &l) in arr.iter().enumerate() {
let idx_eq = index.lit(slv, IntLitMeaning::Eq(i as IntVal));
slv.add_clause([!idx_eq, !l, result])?;
slv.add_clause([!idx_eq, l, !result])?;
}
slv.add_clause(arr.iter().map(|&l| !l).chain(once(result)))?;
slv.add_clause(arr.into_iter().chain(once(!result)))?;
Ok(())
}
}
/// Propagator for the boolean element constraint.
///
/// NOTE(review): `initialize` subscribes to fixings of every involved
/// variable, but `propagate` currently only enforces the index bounds. No
/// inferences about `array` or `result` are made here.
impl<E> Propagator<E> for BoolDecisionArrayElement
where
    E: ReasoningEngine,
    View<IntVal>: IntModelActions<E>,
    View<bool>: BoolModelActions<E>,
{
    /// Requests that the propagator be enqueued whenever an element of
    /// `array`, the `index`, or the `result` becomes fixed.
    fn initialize(&mut self, ctx: &mut E::InitializationContext<'_>) {
        for elem in self.array.iter().copied() {
            elem.enqueue_when_fixed(ctx);
        }
        self.index.enqueue_when(ctx, IntPropCond::Fixed);
        self.result.enqueue_when_fixed(ctx);
    }

    /// Restricts `index` to the valid positions `0` through
    /// `array.len() - 1`.
    ///
    /// The bounds follow from the array length alone, so an empty reason
    /// (`vec![]`) is supplied. For an empty array the upper bound is `-1`,
    /// which is below the lower bound `0`. Presumably this surfaces as a
    /// conflict (confirm against `tighten_max`).
    ///
    /// # Errors
    ///
    /// Returns `E::Conflict` if either bound cannot be imposed.
    fn propagate(&mut self, ctx: &mut E::PropagationContext<'_>) -> Result<(), E::Conflict> {
        let upper = self.array.len() as IntVal - 1;
        self.index.tighten_min(ctx, 0, vec![])?;
        self.index.tighten_max(ctx, upper, vec![])?;
        Ok(())
    }
}