pumpkin_core/constraints/arithmetic/equality.rs

1use super::less_than_or_equals;
2use crate::constraints::Constraint;
3use crate::constraints::NegatableConstraint;
4use crate::proof::ConstraintTag;
5use crate::propagators::binary::BinaryEqualsPropagatorArgs;
6use crate::propagators::binary::BinaryNotEqualsPropagatorArgs;
7use crate::propagators::linear_not_equal::LinearNotEqualPropagatorArgs;
8use crate::propagators::ReifiedPropagatorArgs;
9use crate::variables::IntegerVariable;
10use crate::variables::Literal;
11use crate::variables::TransformableVariable;
12use crate::ConstraintOperationError;
13use crate::Solver;
14
15struct EqualConstraint<Var> {
16    terms: Box<[Var]>,
17    rhs: i32,
18    constraint_tag: ConstraintTag,
19}
20
21/// Creates the [`NegatableConstraint`] `\sum terms_i = rhs`.
22///
23/// Its negation is [`not_equals`].
24pub fn equals<Var: IntegerVariable + Clone + 'static>(
25    terms: impl Into<Box<[Var]>>,
26    rhs: i32,
27    constraint_tag: ConstraintTag,
28) -> impl NegatableConstraint {
29    EqualConstraint {
30        terms: terms.into(),
31        rhs,
32        constraint_tag,
33    }
34}
35
36/// Creates the [`NegatableConstraint`] `lhs = rhs`.
37///
38/// Its negation is [`binary_not_equals`].
39pub fn binary_equals<Var: IntegerVariable + 'static>(
40    lhs: Var,
41    rhs: Var,
42    constraint_tag: ConstraintTag,
43) -> impl NegatableConstraint {
44    EqualConstraint {
45        terms: [lhs.scaled(1), rhs.scaled(-1)].into(),
46        rhs: 0,
47        constraint_tag,
48    }
49}
50
51struct NotEqualConstraint<Var> {
52    terms: Box<[Var]>,
53    rhs: i32,
54    constraint_tag: ConstraintTag,
55}
56
57/// Create the [`NegatableConstraint`] `\sum terms_i != rhs`.
58///
59/// Its negation is [`equals`].
60pub fn not_equals<Var: IntegerVariable + Clone + 'static>(
61    terms: impl Into<Box<[Var]>>,
62    rhs: i32,
63    constraint_tag: ConstraintTag,
64) -> impl NegatableConstraint {
65    equals(terms, rhs, constraint_tag).negation()
66}
67
68/// Creates the [`NegatableConstraint`] `lhs != rhs`.
69///
70/// Its negation is [`binary_equals`].
71pub fn binary_not_equals<Var: IntegerVariable + 'static>(
72    lhs: Var,
73    rhs: Var,
74    constraint_tag: ConstraintTag,
75) -> impl NegatableConstraint {
76    NotEqualConstraint {
77        terms: [lhs.scaled(1), rhs.scaled(-1)].into(),
78        rhs: 0,
79        constraint_tag,
80    }
81}
82
83impl<Var> Constraint for EqualConstraint<Var>
84where
85    Var: IntegerVariable + Clone + 'static,
86{
87    fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> {
88        if self.terms.len() == 2 && !solver.is_logging_full_proof() {
89            let _ = solver.add_propagator(BinaryEqualsPropagatorArgs {
90                a: self.terms[0].clone(),
91                b: self.terms[1].scaled(-1).offset(self.rhs),
92                constraint_tag: self.constraint_tag,
93            })?;
94        } else {
95            less_than_or_equals(self.terms.clone(), self.rhs, self.constraint_tag).post(solver)?;
96
97            let negated = self
98                .terms
99                .iter()
100                .map(|var| var.scaled(-1))
101                .collect::<Box<[_]>>();
102            less_than_or_equals(negated, -self.rhs, self.constraint_tag).post(solver)?;
103        }
104
105        Ok(())
106    }
107
108    fn implied_by(
109        self,
110        solver: &mut Solver,
111        reification_literal: Literal,
112    ) -> Result<(), ConstraintOperationError> {
113        if self.terms.len() == 2 {
114            let _ = solver.add_propagator(ReifiedPropagatorArgs {
115                propagator: BinaryEqualsPropagatorArgs {
116                    a: self.terms[0].clone(),
117                    b: self.terms[1].scaled(-1).offset(self.rhs),
118                    constraint_tag: self.constraint_tag,
119                },
120                reification_literal,
121            })?;
122        } else {
123            less_than_or_equals(self.terms.clone(), self.rhs, self.constraint_tag)
124                .implied_by(solver, reification_literal)?;
125
126            let negated = self
127                .terms
128                .iter()
129                .map(|var| var.scaled(-1))
130                .collect::<Box<[_]>>();
131            less_than_or_equals(negated, -self.rhs, self.constraint_tag)
132                .implied_by(solver, reification_literal)?;
133        }
134
135        Ok(())
136    }
137}
138
139impl<Var> NegatableConstraint for EqualConstraint<Var>
140where
141    Var: IntegerVariable + Clone + 'static,
142{
143    type NegatedConstraint = NotEqualConstraint<Var>;
144
145    fn negation(&self) -> Self::NegatedConstraint {
146        NotEqualConstraint {
147            terms: self.terms.clone(),
148            rhs: self.rhs,
149            constraint_tag: self.constraint_tag,
150        }
151    }
152}
153
154impl<Var> Constraint for NotEqualConstraint<Var>
155where
156    Var: IntegerVariable + Clone + 'static,
157{
158    fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> {
159        let NotEqualConstraint {
160            terms,
161            rhs,
162            constraint_tag,
163        } = self;
164
165        if terms.len() == 2 {
166            let _ = solver.add_propagator(BinaryNotEqualsPropagatorArgs {
167                a: terms[0].clone(),
168                b: terms[1].scaled(-1).offset(self.rhs),
169                constraint_tag: self.constraint_tag,
170            })?;
171
172            Ok(())
173        } else {
174            LinearNotEqualPropagatorArgs {
175                terms: terms.into(),
176                rhs,
177                constraint_tag,
178            }
179            .post(solver)
180        }
181    }
182
183    fn implied_by(
184        self,
185        solver: &mut Solver,
186        reification_literal: Literal,
187    ) -> Result<(), ConstraintOperationError> {
188        let NotEqualConstraint {
189            terms,
190            rhs,
191            constraint_tag,
192        } = self;
193
194        if terms.len() == 2 {
195            let _ = solver.add_propagator(ReifiedPropagatorArgs {
196                propagator: BinaryNotEqualsPropagatorArgs {
197                    a: terms[0].clone(),
198                    b: terms[1].scaled(-1).offset(self.rhs),
199                    constraint_tag: self.constraint_tag,
200                },
201                reification_literal,
202            })?;
203            Ok(())
204        } else {
205            LinearNotEqualPropagatorArgs {
206                terms: terms.into(),
207                rhs,
208                constraint_tag,
209            }
210            .implied_by(solver, reification_literal)
211        }
212    }
213}
214
215impl<Var> NegatableConstraint for NotEqualConstraint<Var>
216where
217    Var: IntegerVariable + Clone + 'static,
218{
219    type NegatedConstraint = EqualConstraint<Var>;
220
221    fn negation(&self) -> Self::NegatedConstraint {
222        EqualConstraint {
223            terms: self.terms.clone(),
224            rhs: self.rhs,
225            constraint_tag: self.constraint_tag,
226        }
227    }
228}