use std::{
num::NonZero,
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
use crate::{
IntSet, IntVal,
actions::{
BoolPropagationActions, IntDecisionActions, IntExplanationActions, IntInspectionActions,
IntPropagationActions, IntSimplificationActions, PropagationActions, ReasoningContext,
},
constraints::{Conflict, ReasonBuilder, int_linear::IntEq},
model::{
Decision, Model, View,
expressions::linear::IntLinearExp,
resolved::Resolved,
view::{DefaultView, boolean::BoolView, private},
},
solver::IntLitMeaning,
views::{LinearBoolView, LinearView},
};
/// Model-level representation of an integer view.
///
/// A view is either a fixed value, a linear transformation of an integer
/// [`Decision`], or a linear transformation of a Boolean view (whose `false` /
/// `true` values are read as `0` / `1`).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IntView {
    /// A constant integer value.
    Const(IntVal),
    /// A linear view (`scale * var + offset`) over an integer decision.
    Linear(LinearView<NonZero<IntVal>, IntVal, Decision<IntVal>>),
    /// A linear view over a Boolean view.
    Bool(LinearBoolView<NonZero<IntVal>, IntVal, View<bool>>),
}
impl private::Sealed for IntVal {}

/// Integer values use [`IntView`] as their `DefaultView::View` type.
impl DefaultView for IntVal {
    type View = IntView;
}
impl Resolved<View<IntVal>> {
    /// Excludes every value in `values` from the domain of this resolved view.
    ///
    /// A linear view first maps `values` back onto its underlying decision with
    /// `reverse_intset`; constant and Boolean views delegate directly.
    pub(crate) fn exclude(
        self,
        ctx: &mut Model,
        values: &IntSet,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        match self.0.0 {
            IntView::Const(v) => v.exclude(ctx, values, reason),
            IntView::Linear(lin) => {
                Resolved(lin.var).exclude(ctx, &lin.reverse_intset(values), reason)
            }
            IntView::Bool(lin) => lin.exclude(ctx, values, reason),
        }
    }

    /// Fixes this resolved view to `val`.
    ///
    /// For a linear view, `val` is mapped back with `try_reverse_val`. If no
    /// value of the underlying decision maps to `val`, a conflict is declared
    /// using `reason`.
    pub(crate) fn fix(
        self,
        ctx: &mut Model,
        val: IntVal,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        match self.0.0 {
            IntView::Const(v) => v.fix(ctx, val, reason),
            IntView::Linear(lin) => {
                let Some(val) = lin.try_reverse_val(val) else {
                    return Err(ctx.declare_conflict(reason));
                };
                Resolved(lin.var).fix(ctx, val, reason)
            }
            IntView::Bool(lin) => lin.fix(ctx, val, reason),
        }
    }

    /// Removes `val` from the domain of this resolved view.
    ///
    /// For a linear view, a `val` that `try_reverse_val` cannot map back is
    /// never taken by the view, so nothing is done.
    pub(crate) fn remove_val(
        self,
        ctx: &mut Model,
        val: IntVal,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        match self.0.0 {
            IntView::Const(v) => v.remove_val(ctx, val, reason),
            IntView::Linear(lin) => {
                let Some(val) = lin.try_reverse_val(val) else {
                    return Ok(());
                };
                Resolved(lin.var).remove_val(ctx, val, reason)
            }
            IntView::Bool(lin) => lin.remove_val(ctx, val, reason),
        }
    }

    /// Restricts the domain of this resolved view to `values`.
    ///
    /// A linear view first maps `values` back onto its underlying decision with
    /// `reverse_intset`; constant and Boolean views delegate directly.
    pub(crate) fn restrict_domain(
        self,
        ctx: &mut Model,
        values: &IntSet,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        match self.0.0 {
            IntView::Const(v) => v.restrict_domain(ctx, values, reason),
            IntView::Linear(lin) => {
                Resolved(lin.var).restrict_domain(ctx, &lin.reverse_intset(values), reason)
            }
            IntView::Bool(lin) => lin.restrict_domain(ctx, values, reason),
        }
    }

    /// Tightens the upper bound of this resolved view to `ub`.
    ///
    /// For a linear view with a positive scale this becomes an upper bound on
    /// the decision (via `reverse_val_floor`). With a negative scale it becomes
    /// a lower bound (via `reverse_val_ceil`).
    pub(crate) fn tighten_max(
        self,
        ctx: &mut Model,
        ub: IntVal,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        match self.0.0 {
            IntView::Const(v) => v.tighten_max(ctx, ub, reason),
            IntView::Linear(lin) => {
                if lin.scale.get() >= 0 {
                    Resolved(lin.var).tighten_max(ctx, lin.reverse_val_floor(ub), reason)
                } else {
                    Resolved(lin.var).tighten_min(ctx, lin.reverse_val_ceil(ub), reason)
                }
            }
            IntView::Bool(lin) => lin.tighten_max(ctx, ub, reason),
        }
    }

    /// Tightens the lower bound of this resolved view to `val`.
    ///
    /// For a linear view with a positive scale this becomes a lower bound on
    /// the decision (via `reverse_val_ceil`). With a negative scale it becomes
    /// an upper bound (via `reverse_val_floor`).
    pub(crate) fn tighten_min(
        self,
        ctx: &mut Model,
        val: IntVal,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        match self.0.0 {
            IntView::Const(v) => v.tighten_min(ctx, val, reason),
            IntView::Linear(lin) => {
                if lin.scale.get() >= 0 {
                    Resolved(lin.var).tighten_min(ctx, lin.reverse_val_ceil(val), reason)
                } else {
                    Resolved(lin.var).tighten_max(ctx, lin.reverse_val_floor(val), reason)
                }
            }
            IntView::Bool(lin) => lin.tighten_min(ctx, val, reason),
        }
    }

    /// Declares that this resolved view and `other` take the same value.
    ///
    /// Depending on the pair of views, one of the following happens:
    /// - identical views: nothing happens;
    /// - the unification is infeasible: a conflict is reported;
    /// - a decision is fixed;
    /// - one decision is aliased (through `unify_internal`) to a view of another;
    /// - no integral aliasing exists: an [`IntEq`] constraint is posted.
    pub(crate) fn unify(
        self,
        ctx: &mut Model,
        other: Resolved<View<IntVal>>,
    ) -> Result<(), Conflict<View<bool>>> {
        use IntView::*;

        /// Returns `(scale, offset)` such that `from == to` is equivalent to
        /// `from.var == scale * to.var + offset`. Returns `None` when a division
        /// is inexact or any intermediate computation would overflow.
        fn rewrite(
            from: LinearView<NonZero<IntVal>, IntVal, Decision<IntVal>>,
            to: LinearView<NonZero<IntVal>, IntVal, Decision<IntVal>>,
        ) -> Option<(NonZero<IntVal>, IntVal)> {
            let divisor = from.scale.get();
            // Checked arithmetic avoids the `MIN % -1` / `MIN / -1` panics.
            let exact_div = |n: IntVal| match n.checked_rem(divisor) {
                Some(0) => n.checked_div(divisor),
                _ => None,
            };
            let scale = NonZero::new(exact_div(to.scale.get())?)?;
            let offset = exact_div(to.offset.checked_sub(from.offset)?)?;
            Some((scale, offset))
        }

        let (idx, target) = match (self.0.0, other.0.0) {
            (x, y) if x == y => return Ok(()),
            (Bool(x), Bool(y)) => return x.unify(ctx, y),
            (Const(x), Const(y)) if x != y => return Err(ctx.declare_conflict([])),
            (Const(y), x) | (x, Const(y)) => {
                let x = View::<IntVal>(x);
                return x.fix(ctx, y, []);
            }
            (Linear(lin_x), Linear(lin_y)) if lin_x.var.0 == lin_y.var.0 => {
                // Both views range over the same decision `X`:
                // `a*X + b == c*X + d` iff `(a - c) * X == d - b`. Aliasing `X` to
                // a view of itself would be circular, so solve for `X` instead.
                let coef = lin_x.scale.get().checked_sub(lin_y.scale.get());
                let rhs = lin_y.offset.checked_sub(lin_x.offset);
                match (coef, rhs) {
                    // Equal scales; the views differ (checked above), so the
                    // offsets differ and no value of `X` satisfies both.
                    (Some(0), _) => return Err(ctx.declare_conflict([])),
                    (Some(coef), Some(rhs)) => {
                        return match rhs.checked_rem(coef) {
                            Some(0) => Resolved(lin_x.var).fix(ctx, rhs / coef, []),
                            // No (representable) integer `X` solves the equation.
                            _ => Err(ctx.declare_conflict([])),
                        };
                    }
                    _ => {
                        // Normalisation overflowed: fall back to a constraint.
                        ctx.post_constraint_internal(IntEq {
                            vars: [self.0, other.0],
                        });
                        return Ok(());
                    }
                }
            }
            (Linear(lin_x), Linear(lin_y)) => {
                // With `x = a*X + b` and `y = c*Y + d`, `X` can be aliased to
                // `(c/a)*Y + (d-b)/a` when that is integral (symmetrically for
                // `Y`). When both are possible, the decision whose inner value
                // compares greater is the one that gets aliased.
                let (var, (scale, offset), base) =
                    match (rewrite(lin_x, lin_y), rewrite(lin_y, lin_x)) {
                        (Some(def_x), Some(_)) if lin_x.var.0 > lin_y.var.0 => {
                            (lin_x.var, def_x, lin_y.var)
                        }
                        (_, Some(def_y)) => (lin_y.var, def_y, lin_x.var),
                        (Some(def_x), None) => (lin_x.var, def_x, lin_y.var),
                        (None, None) => {
                            ctx.post_constraint_internal(IntEq {
                                vars: [self.0, other.0],
                            });
                            return Ok(());
                        }
                    };
                (var, View(Linear(LinearView::new(scale, offset, base))))
            }
            (Linear(lin), Bool(b)) | (Bool(b), Linear(lin)) => {
                // The Boolean view can only take the two values it maps `0`/`1` to.
                let lb = b.transform_val(0);
                let ub = b.transform_val(1);
                let contains_lb = lin.in_domain(ctx, lb);
                let contains_ub = lin.in_domain(ctx, ub);
                match (contains_lb, contains_ub) {
                    (false, false) => {
                        // Neither value is possible for the linear view.
                        return Err(ctx.declare_conflict(|ctx: &mut Model| {
                            [
                                lin.lit(ctx, IntLitMeaning::NotEq(lb)),
                                lin.lit(ctx, IntLitMeaning::NotEq(ub)),
                            ]
                        }));
                    }
                    (false, true) => {
                        // Only the value for `true` is possible.
                        let Some(val) = lin.try_reverse_val(ub) else {
                            unreachable!()
                        };
                        Resolved(lin.var).fix(ctx, val, [])?;
                        return b.var.require(ctx, |ctx: &mut Model| {
                            [lin.lit(ctx, IntLitMeaning::NotEq(lb))]
                        });
                    }
                    (true, false) => {
                        // Only the value for `false` is possible.
                        let Some(val) = lin.try_reverse_val(lb) else {
                            unreachable!()
                        };
                        Resolved(lin.var).fix(ctx, val, [])?;
                        return b.var.fix(ctx, false, |ctx: &mut Model| {
                            [lin.lit(ctx, IntLitMeaning::NotEq(ub))]
                        });
                    }
                    (true, true) => {
                        // Alias the decision to a linear view over the Boolean,
                        // mapping `false`/`true` to the decision values for `lb`/`ub`.
                        let Ok(IntLitMeaning::Eq(i_lb)) =
                            lin.reverse_meaning(IntLitMeaning::Eq(lb))
                        else {
                            unreachable!()
                        };
                        let Ok(IntLitMeaning::Eq(i_ub)) =
                            lin.reverse_meaning(IntLitMeaning::Eq(ub))
                        else {
                            unreachable!()
                        };
                        let target = View(Bool(LinearBoolView::new(
                            NonZero::new(i_ub - i_lb).unwrap(),
                            i_lb,
                            b.var,
                        )));
                        (lin.var, target)
                    }
                }
            }
        };
        Resolved(idx).unify_internal(ctx, target)
    }
}
impl IntDecisionActions<Model> for Resolved<View<IntVal>> {
    /// Returns the literal for `meaning`. Panics if `try_lit` yields `None`.
    fn lit(&self, ctx: &mut Model, meaning: IntLitMeaning) -> View<bool> {
        let found = IntInspectionActions::try_lit(self, ctx, meaning);
        found.unwrap()
    }

    /// Returns the `== value` literal when the view currently has a value,
    /// and `None` otherwise.
    fn val_lit(&self, ctx: &mut Model) -> Option<View<bool>> {
        let fixed = self.val(ctx)?;
        let found = IntInspectionActions::try_lit(self, ctx, IntLitMeaning::Eq(fixed));
        Some(found.unwrap())
    }
}
impl IntInspectionActions<Model> for Resolved<View<IntVal>> {
fn bounds(&self, ctx: &Model) -> (IntVal, IntVal) {
match self.0.0 {
IntView::Const(v) => (v, v),
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).bounds(ctx)
}
IntView::Bool(lin) => lin.bounds(ctx),
}
}
fn domain(&self, ctx: &Model) -> IntSet {
match self.0.0 {
IntView::Const(c) => (c..=c).into(),
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).domain(ctx)
}
IntView::Bool(lin) => lin.domain(ctx),
}
}
fn in_domain(&self, ctx: &Model, val: IntVal) -> bool {
match self.0.0 {
IntView::Const(v) => v == val,
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).in_domain(ctx, val)
}
IntView::Bool(lin) => lin.in_domain(ctx, val),
}
}
fn lit_meaning(&self, ctx: &Model, lit: View<bool>) -> Option<IntLitMeaning> {
match self.0.0 {
IntView::Const(_) => None,
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).lit_meaning(ctx, lit)
}
IntView::Bool(lin) => lin.lit_meaning(ctx, lit),
}
}
fn max(&self, ctx: &Model) -> IntVal {
match self.0.0 {
IntView::Const(v) => v,
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).max(ctx)
}
IntView::Bool(lin) => lin.max(ctx),
}
}
fn max_lit(&self, ctx: &Model) -> View<bool> {
match self.0.0 {
IntView::Const(_) => true.into(),
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).max_lit(ctx)
}
IntView::Bool(lin) => lin.max_lit(ctx),
}
}
fn min(&self, ctx: &Model) -> IntVal {
match self.0.0 {
IntView::Const(v) => v,
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).min(ctx)
}
IntView::Bool(lin) => lin.min(ctx),
}
}
fn min_lit(&self, ctx: &Model) -> View<bool> {
match self.0.0 {
IntView::Const(_) => true.into(),
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).min_lit(ctx)
}
IntView::Bool(lin) => lin.min_lit(ctx),
}
}
fn try_lit(&self, ctx: &Model, meaning: IntLitMeaning) -> Option<View<bool>> {
match self.0.0 {
IntView::Const(v) => Some(match meaning {
IntLitMeaning::Eq(i) => (v == i).into(),
IntLitMeaning::NotEq(i) => (v != i).into(),
IntLitMeaning::GreaterEq(i) => (v >= i).into(),
IntLitMeaning::Less(i) => (v < i).into(),
}),
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).try_lit(ctx, meaning)
}
IntView::Bool(lin) => lin.try_lit(ctx, meaning),
}
}
fn val(&self, ctx: &Model) -> Option<IntVal> {
match self.0.0 {
IntView::Const(v) => Some(v),
IntView::Linear(lin) => {
LinearView::new(lin.scale, lin.offset, Resolved(lin.var)).val(ctx)
}
IntView::Bool(lin) => lin.val(ctx),
}
}
}
impl View<IntVal> {
    /// Returns the view `self + rhs`, after restricting `self` so that the
    /// addition cannot overflow [`IntVal`].
    ///
    /// For a positive `rhs`, the upper bound of `self` is tightened to
    /// `IntVal::MAX - rhs`. For a negative `rhs`, the lower bound is tightened
    /// to `IntVal::MIN - rhs`. Only values whose sum would overflow are
    /// removed, and the tightening uses an empty reason.
    ///
    /// # Errors
    ///
    /// Propagates any conflict returned by the bound tightening.
    pub fn bounding_add<Ctx>(
        self,
        ctx: &mut Ctx,
        rhs: IntVal,
    ) -> Result<View<IntVal>, Ctx::Conflict>
    where
        Ctx: PropagationActions + ReasoningContext + ?Sized,
        View<IntVal>: IntPropagationActions<Ctx>,
    {
        if rhs.is_positive() {
            // `x + rhs <= MAX` iff `x <= MAX - rhs`; cannot overflow for `rhs > 0`.
            let limit = IntVal::MAX - rhs;
            if self.max(ctx) > limit {
                self.tighten_max(ctx, limit, [])?;
            }
        } else if rhs.is_negative() {
            // `x + rhs >= MIN` iff `x >= MIN - rhs`; cannot overflow for `rhs < 0`.
            let limit = IntVal::MIN - rhs;
            if self.min(ctx) < limit {
                self.tighten_min(ctx, limit, [])?;
            }
        }
        Ok(self + rhs)
    }

    /// Returns the view `self * rhs`, after restricting `self` so that the
    /// multiplication cannot overflow [`IntVal`].
    ///
    /// Each bound of `self` whose product with `rhs` overflows is tightened.
    /// The limit is `IntVal::MIN`/`IntVal::MAX` divided by `rhs`. Rust integer
    /// division truncates towards zero, so this keeps every value whose
    /// product still fits.
    ///
    /// # Errors
    ///
    /// Propagates any conflict returned by the bound tightening. A conflict is
    /// also declared when the limit itself cannot be computed.
    pub fn bounding_mul<Ctx>(
        self,
        ctx: &mut Ctx,
        rhs: IntVal,
    ) -> Result<View<IntVal>, Ctx::Conflict>
    where
        Ctx: PropagationActions + ReasoningContext + ?Sized,
        View<IntVal>: IntPropagationActions<Ctx>,
    {
        let (lb, ub) = self.bounds(ctx);
        // Multiplying by a negative `rhs` swaps which extreme each bound approaches.
        let (min, max) = if rhs.is_positive() {
            (IntVal::MIN, IntVal::MAX)
        } else {
            (IntVal::MAX, IntVal::MIN)
        };
        if lb.checked_mul(rhs).is_none() {
            if let Some(lb) = min.checked_div(rhs) {
                self.tighten_min(ctx, lb, [])?;
            } else {
                return Err(ctx.declare_conflict([]));
            }
        }
        if ub.checked_mul(rhs).is_none() {
            if let Some(ub) = max.checked_div(rhs) {
                self.tighten_max(ctx, ub, [])?;
            } else {
                return Err(ctx.declare_conflict(|ctx: &mut Ctx| {
                    [self.lit(ctx, IntLitMeaning::Less(IntVal::MIN))]
                }));
            }
        }
        Ok(self * rhs)
    }

    /// Returns the view `-self`, after raising the lower bound of `self` to
    /// `-IntVal::MAX` when it is `IntVal::MIN` (whose negation does not fit).
    ///
    /// # Errors
    ///
    /// Propagates any conflict returned by the bound tightening.
    pub fn bounding_neg<Ctx>(self, ctx: &mut Ctx) -> Result<View<IntVal>, Ctx::Conflict>
    where
        Ctx: ReasoningContext + ?Sized,
        View<IntVal>: IntPropagationActions<Ctx>,
    {
        if self.min(ctx) == IntVal::MIN {
            self.tighten_min(ctx, -IntVal::MAX, [])?;
        }
        Ok(-self)
    }

    /// Returns the view `self - rhs`, after restricting `self` so that the
    /// subtraction cannot overflow [`IntVal`] (see [`View::bounding_add`]).
    ///
    /// # Errors
    ///
    /// Propagates any conflict returned by the bound tightening.
    pub fn bounding_sub<Ctx>(
        self,
        ctx: &mut Ctx,
        rhs: IntVal,
    ) -> Result<View<IntVal>, Ctx::Conflict>
    where
        Ctx: PropagationActions + ReasoningContext + ?Sized,
        View<IntVal>: IntPropagationActions<Ctx>,
    {
        match rhs.checked_neg() {
            Some(neg) => self.bounding_add(ctx, neg),
            // `-IntVal::MIN` does not fit, so use `x - MIN == (x + MAX) + 1`.
            None => self.bounding_add(ctx, IntVal::MAX)?.bounding_add(ctx, 1),
        }
    }

    /// Returns a Boolean view that holds exactly when this view equals `v`.
    ///
    /// A constant Boolean view is returned when the answer is already known
    /// (constant views, or values the linear mapping can never produce).
    pub fn eq(&self, v: IntVal) -> View<bool> {
        use IntView::*;
        match self.0 {
            Const(c) => (c == v).into(),
            Linear(lin) => match lin.reverse_meaning(IntLitMeaning::Eq(v)) {
                Ok(IntLitMeaning::Eq(val)) => View(BoolView::IntEq(lin.var, val)),
                Err(b) => {
                    debug_assert!(!b);
                    false.into()
                }
                _ => unreachable!(),
            },
            Bool(lin) => match lin.reverse_meaning(IntLitMeaning::Eq(v)) {
                Ok(IntLitMeaning::Eq(1)) => lin.var,
                Ok(IntLitMeaning::Eq(0)) => !lin.var,
                // Only `0` and `1` are reachable through the Boolean.
                Ok(IntLitMeaning::Eq(_)) => false.into(),
                Err(b) => {
                    debug_assert!(!b);
                    false.into()
                }
                _ => unreachable!(),
            },
        }
    }

    /// Returns a Boolean view that holds exactly when this view is `>= v`.
    pub fn geq(&self, v: IntVal) -> View<bool> {
        !self.lt(v)
    }

    /// Returns a Boolean view that holds exactly when this view is `> v`.
    ///
    /// No value exceeds `IntVal::MAX`, so `gt(IntVal::MAX)` is constant `false`.
    pub fn gt(&self, v: IntVal) -> View<bool> {
        match v.checked_add(1) {
            Some(next) => self.geq(next),
            None => false.into(),
        }
    }

    /// Returns a Boolean view that holds exactly when this view is `<= v`.
    ///
    /// Every value is at most `IntVal::MAX`, so `leq(IntVal::MAX)` is constant
    /// `true`.
    pub fn leq(&self, v: IntVal) -> View<bool> {
        match v.checked_add(1) {
            Some(next) => self.lt(next),
            None => true.into(),
        }
    }

    /// Returns a Boolean view that holds exactly when this view is `< v`.
    pub fn lt(&self, v: IntVal) -> View<bool> {
        use IntView::*;
        match self.0 {
            Const(c) => (c < v).into(),
            // A negative scale turns `<` on the view into `>=` on the decision.
            Linear(lin) => match lin.reverse_meaning(IntLitMeaning::Less(v)) {
                Ok(IntLitMeaning::GreaterEq(val)) => View(BoolView::IntGreaterEq(lin.var, val)),
                Ok(IntLitMeaning::Less(val)) => View(BoolView::IntLess(lin.var, val)),
                _ => unreachable!(),
            },
            // The underlying Boolean only takes `0` or `1`.
            Bool(lin) => match lin.reverse_meaning(IntLitMeaning::Less(v)) {
                Ok(IntLitMeaning::GreaterEq(1)) => lin.var,
                Ok(IntLitMeaning::GreaterEq(val)) if val > 1 => false.into(),
                Ok(IntLitMeaning::GreaterEq(_)) => true.into(),
                Ok(IntLitMeaning::Less(1)) => !lin.var,
                Ok(IntLitMeaning::Less(val)) if val > 1 => true.into(),
                Ok(IntLitMeaning::Less(_)) => false.into(),
                _ => unreachable!(),
            },
        }
    }

    /// Returns a Boolean view that holds exactly when this view differs from `v`.
    pub fn ne(&self, v: IntVal) -> View<bool> {
        !self.eq(v)
    }
}
impl Add<IntVal> for View<IntVal> {
type Output = Self;
fn add(self, rhs: IntVal) -> Self::Output {
use IntView::*;
if rhs == 0 {
return self;
}
View(match self.0 {
Const(v) => Const(v + rhs),
Linear(lin) => Linear(lin + rhs),
Bool(lin) => Bool(lin + rhs),
})
}
}
impl Add<View<IntVal>> for View<IntVal> {
    type Output = IntLinearExp;

    /// Sums two views into a linear expression.
    fn add(self, rhs: View<IntVal>) -> Self::Output {
        let lhs: IntLinearExp = IntLinearExp::from(self);
        lhs.add(rhs)
    }
}
impl AddAssign<IntVal> for View<IntVal> {
fn add_assign(&mut self, rhs: IntVal) {
use IntView::*;
if rhs == 0 {
return;
}
match &mut self.0 {
Const(v) => *v += rhs,
Linear(lin) => *lin += rhs,
Bool(lin) => *lin += rhs,
};
}
}
impl From<Decision<IntVal>> for View<IntVal> {
fn from(decision: Decision<IntVal>) -> Self {
View(IntView::Linear(decision.into()))
}
}
impl From<View<bool>> for View<IntVal> {
fn from(value: View<bool>) -> Self {
match value.0 {
BoolView::Const(b) => (b as IntVal).into(),
_ => View(IntView::Bool(value.into())),
}
}
}
impl From<i64> for View<IntVal> {
fn from(value: i64) -> Self {
View(IntView::Const(value))
}
}
/// Decision actions on an unresolved view first resolve aliases, then delegate.
impl IntDecisionActions<Model> for View<IntVal> {
    fn lit(&self, ctx: &mut Model, meaning: IntLitMeaning) -> View<bool> {
        let resolved = self.resolve_alias(ctx);
        resolved.lit(ctx, meaning)
    }

    fn val_lit(&self, ctx: &mut Model) -> Option<View<bool>> {
        let resolved = self.resolve_alias(ctx);
        resolved.val_lit(ctx)
    }
}
impl IntExplanationActions<Model> for View<IntVal> {
    /// Returns the exact literal for `meaning` (no relaxation is applied),
    /// paired with `meaning` itself. Panics if `try_lit` yields `None`.
    fn lit_relaxed(&self, ctx: &Model, meaning: IntLitMeaning) -> (View<bool>, IntLitMeaning) {
        let exact = self.try_lit(ctx, meaning).unwrap();
        (exact, meaning)
    }
}
/// Inspection of an unresolved view: resolve aliases, then delegate to the
/// resolved view.
impl IntInspectionActions<Model> for View<IntVal> {
    fn bounds(&self, ctx: &Model) -> (IntVal, IntVal) {
        let resolved = self.resolve_alias(ctx);
        resolved.bounds(ctx)
    }

    fn domain(&self, ctx: &Model) -> IntSet {
        let resolved = self.resolve_alias(ctx);
        resolved.domain(ctx)
    }

    fn in_domain(&self, ctx: &Model, val: IntVal) -> bool {
        let resolved = self.resolve_alias(ctx);
        resolved.in_domain(ctx, val)
    }

    fn lit_meaning(&self, ctx: &Model, lit: View<bool>) -> Option<IntLitMeaning> {
        let resolved = self.resolve_alias(ctx);
        resolved.lit_meaning(ctx, lit)
    }

    fn max(&self, ctx: &Model) -> IntVal {
        let resolved = self.resolve_alias(ctx);
        resolved.max(ctx)
    }

    fn max_lit(&self, ctx: &Model) -> View<bool> {
        let resolved = self.resolve_alias(ctx);
        resolved.max_lit(ctx)
    }

    fn min(&self, ctx: &Model) -> IntVal {
        let resolved = self.resolve_alias(ctx);
        resolved.min(ctx)
    }

    fn min_lit(&self, ctx: &Model) -> View<bool> {
        let resolved = self.resolve_alias(ctx);
        resolved.min_lit(ctx)
    }

    fn try_lit(&self, ctx: &Model, meaning: IntLitMeaning) -> Option<View<bool>> {
        let resolved = self.resolve_alias(ctx);
        resolved.try_lit(ctx, meaning)
    }

    fn val(&self, ctx: &Model) -> Option<IntVal> {
        let resolved = self.resolve_alias(ctx);
        resolved.val(ctx)
    }
}
/// Propagation on an unresolved view: resolve aliases, then apply the change
/// to the resolved view.
impl IntPropagationActions<Model> for View<IntVal> {
    fn fix(
        &self,
        ctx: &mut Model,
        val: IntVal,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        let resolved = self.resolve_alias(ctx);
        resolved.fix(ctx, val, reason)
    }

    fn remove_val(
        &self,
        ctx: &mut Model,
        val: IntVal,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        let resolved = self.resolve_alias(ctx);
        resolved.remove_val(ctx, val, reason)
    }

    fn tighten_max(
        &self,
        ctx: &mut Model,
        ub: IntVal,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        let resolved = self.resolve_alias(ctx);
        resolved.tighten_max(ctx, ub, reason)
    }

    fn tighten_min(
        &self,
        ctx: &mut Model,
        val: IntVal,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        let resolved = self.resolve_alias(ctx);
        resolved.tighten_min(ctx, val, reason)
    }
}
/// Simplification on an unresolved view: resolve aliases, then delegate.
impl IntSimplificationActions<Model> for View<IntVal> {
    fn exclude(
        &self,
        ctx: &mut Model,
        values: &IntSet,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        let resolved = self.resolve_alias(ctx);
        resolved.exclude(ctx, values, reason)
    }

    fn restrict_domain(
        &self,
        ctx: &mut Model,
        values: &IntSet,
        reason: impl ReasonBuilder<Model>,
    ) -> Result<(), Conflict<View<bool>>> {
        let resolved = self.resolve_alias(ctx);
        resolved.restrict_domain(ctx, values, reason)
    }

    /// Unifies this view with `other`. `other` is resolved first, then `self`,
    /// and the two resolved views are unified.
    fn unify(&self, ctx: &mut Model, other: impl Into<Self>) -> Result<(), Conflict<View<bool>>> {
        let rhs: Self = other.into();
        let rhs = rhs.resolve_alias(ctx);
        let lhs = self.resolve_alias(ctx);
        lhs.unify(ctx, rhs)
    }
}
impl Mul<IntVal> for View<IntVal> {
    type Output = Self;

    /// Scales the view by `rhs` via `MulAssign<IntVal>`.
    fn mul(self, rhs: IntVal) -> Self::Output {
        let mut scaled = self;
        scaled *= rhs;
        scaled
    }
}

impl Mul<NonZero<IntVal>> for View<IntVal> {
    type Output = Self;

    /// Scales the view by the non-zero `rhs` via `MulAssign<NonZero<IntVal>>`.
    fn mul(self, rhs: NonZero<IntVal>) -> Self::Output {
        let mut scaled = self;
        scaled *= rhs;
        scaled
    }
}
impl MulAssign<IntVal> for View<IntVal> {
    /// Scales the view in place. Multiplying by zero turns it into the
    /// constant `0`.
    fn mul_assign(&mut self, rhs: IntVal) {
        match NonZero::new(rhs) {
            Some(factor) => *self *= factor,
            None => *self = 0.into(),
        }
    }
}
impl MulAssign<NonZero<IntVal>> for View<IntVal> {
fn mul_assign(&mut self, rhs: NonZero<IntVal>) {
use IntView::*;
match &mut self.0 {
Const(v) => *v *= rhs.get(),
Linear(lin) => *lin *= rhs,
Bool(lin) => *lin *= rhs,
}
}
}
impl Neg for View<IntVal> {
type Output = Self;
fn neg(self) -> Self::Output {
use IntView::*;
View(match self.0 {
Const(v) => Const(-v),
Linear(lin) => Linear(-lin),
Bool(lin) => Bool(-lin),
})
}
}
impl Sub<IntVal> for View<IntVal> {
    type Output = Self;

    /// Subtracts a constant by adding its negation.
    fn sub(self, rhs: IntVal) -> Self::Output {
        let negated = -rhs;
        self + negated
    }
}

impl Sub<View<IntVal>> for View<IntVal> {
    type Output = <Self as Add<View<IntVal>>>::Output;

    /// Subtracts another view by adding its negation, yielding a linear expression.
    fn sub(self, rhs: View<IntVal>) -> Self::Output {
        let negated = -rhs;
        self + negated
    }
}

impl SubAssign<IntVal> for View<IntVal> {
    /// Subtracts a constant in place by adding its negation.
    fn sub_assign(&mut self, rhs: IntVal) {
        let negated = -rhs;
        *self += negated;
    }
}