Skip to main content

pumpkin_constraints/constraints/arithmetic/
equality.rs

1use pumpkin_core::ConstraintOperationError;
2use pumpkin_core::Solver;
3use pumpkin_core::constraints::Constraint;
4use pumpkin_core::constraints::NegatableConstraint;
5use pumpkin_core::options::ReifiedPropagatorArgs;
6use pumpkin_core::proof::ConstraintTag;
7use pumpkin_core::variables::IntegerVariable;
8use pumpkin_core::variables::Literal;
9use pumpkin_core::variables::TransformableVariable;
10use pumpkin_propagators::arithmetic::BinaryEqualsPropagatorArgs;
11use pumpkin_propagators::arithmetic::BinaryNotEqualsPropagatorArgs;
12use pumpkin_propagators::arithmetic::LinearNotEqualPropagatorArgs;
13
14use super::less_than_or_equals;
15
16struct EqualConstraint<Var> {
17    terms: Box<[Var]>,
18    rhs: i32,
19    constraint_tag: ConstraintTag,
20}
21
22/// Creates the [`NegatableConstraint`] `∑ terms_i = rhs`.
23///
24/// Its negation is [`not_equals`].
25pub fn equals<Var: IntegerVariable + Clone + 'static>(
26    terms: impl Into<Box<[Var]>>,
27    rhs: i32,
28    constraint_tag: ConstraintTag,
29) -> impl NegatableConstraint {
30    EqualConstraint {
31        terms: terms.into(),
32        rhs,
33        constraint_tag,
34    }
35}
36
37/// Creates the [`NegatableConstraint`] `lhs = rhs`.
38///
39/// Its negation is [`binary_not_equals`].
40pub fn binary_equals<Var: IntegerVariable + 'static>(
41    lhs: Var,
42    rhs: Var,
43    constraint_tag: ConstraintTag,
44) -> impl NegatableConstraint {
45    EqualConstraint {
46        terms: [lhs.scaled(1), rhs.scaled(-1)].into(),
47        rhs: 0,
48        constraint_tag,
49    }
50}
51
52struct NotEqualConstraint<Var> {
53    terms: Box<[Var]>,
54    rhs: i32,
55    constraint_tag: ConstraintTag,
56}
57
58/// Create the [`NegatableConstraint`] `∑ terms_i != rhs`.
59///
60/// Its negation is [`equals`].
61pub fn not_equals<Var: IntegerVariable + Clone + 'static>(
62    terms: impl Into<Box<[Var]>>,
63    rhs: i32,
64    constraint_tag: ConstraintTag,
65) -> impl NegatableConstraint {
66    equals(terms, rhs, constraint_tag).negation()
67}
68
69/// Creates the [`NegatableConstraint`] `lhs != rhs`.
70///
71/// Its negation is [`binary_equals`].
72pub fn binary_not_equals<Var: IntegerVariable + 'static>(
73    lhs: Var,
74    rhs: Var,
75    constraint_tag: ConstraintTag,
76) -> impl NegatableConstraint {
77    NotEqualConstraint {
78        terms: [lhs.scaled(1), rhs.scaled(-1)].into(),
79        rhs: 0,
80        constraint_tag,
81    }
82}
83
84impl<Var> Constraint for EqualConstraint<Var>
85where
86    Var: IntegerVariable + Clone + 'static,
87{
88    fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> {
89        if self.terms.len() == 2 && !solver.is_logging_proof() {
90            let _ = solver.add_propagator(BinaryEqualsPropagatorArgs {
91                a: self.terms[0].clone(),
92                b: self.terms[1].scaled(-1).offset(self.rhs),
93                constraint_tag: self.constraint_tag,
94            })?;
95        } else {
96            less_than_or_equals(self.terms.clone(), self.rhs, self.constraint_tag).post(solver)?;
97
98            let negated = self
99                .terms
100                .iter()
101                .map(|var| var.scaled(-1))
102                .collect::<Box<[_]>>();
103            less_than_or_equals(negated, -self.rhs, self.constraint_tag).post(solver)?;
104        }
105
106        Ok(())
107    }
108
109    fn implied_by(
110        self,
111        solver: &mut Solver,
112        reification_literal: Literal,
113    ) -> Result<(), ConstraintOperationError> {
114        if self.terms.len() == 2 && !solver.is_logging_proof() {
115            let _ = solver.add_propagator(ReifiedPropagatorArgs {
116                propagator: BinaryEqualsPropagatorArgs {
117                    a: self.terms[0].clone(),
118                    b: self.terms[1].scaled(-1).offset(self.rhs),
119                    constraint_tag: self.constraint_tag,
120                },
121                reification_literal,
122            })?;
123        } else {
124            less_than_or_equals(self.terms.clone(), self.rhs, self.constraint_tag)
125                .implied_by(solver, reification_literal)?;
126
127            let negated = self
128                .terms
129                .iter()
130                .map(|var| var.scaled(-1))
131                .collect::<Box<[_]>>();
132            less_than_or_equals(negated, -self.rhs, self.constraint_tag)
133                .implied_by(solver, reification_literal)?;
134        }
135
136        Ok(())
137    }
138}
139
140impl<Var> NegatableConstraint for EqualConstraint<Var>
141where
142    Var: IntegerVariable + Clone + 'static,
143{
144    type NegatedConstraint = NotEqualConstraint<Var>;
145
146    fn negation(&self) -> Self::NegatedConstraint {
147        NotEqualConstraint {
148            terms: self.terms.clone(),
149            rhs: self.rhs,
150            constraint_tag: self.constraint_tag,
151        }
152    }
153}
154
155impl<Var> Constraint for NotEqualConstraint<Var>
156where
157    Var: IntegerVariable + Clone + 'static,
158{
159    fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> {
160        let NotEqualConstraint {
161            terms,
162            rhs,
163            constraint_tag,
164        } = self;
165
166        if terms.len() == 2 {
167            let _ = solver.add_propagator(BinaryNotEqualsPropagatorArgs {
168                a: terms[0].clone(),
169                b: terms[1].scaled(-1).offset(self.rhs),
170                constraint_tag: self.constraint_tag,
171            })?;
172
173            Ok(())
174        } else {
175            LinearNotEqualPropagatorArgs {
176                terms: terms.into(),
177                rhs,
178                constraint_tag,
179            }
180            .post(solver)
181        }
182    }
183
184    fn implied_by(
185        self,
186        solver: &mut Solver,
187        reification_literal: Literal,
188    ) -> Result<(), ConstraintOperationError> {
189        let NotEqualConstraint {
190            terms,
191            rhs,
192            constraint_tag,
193        } = self;
194
195        if terms.len() == 2 {
196            let _ = solver.add_propagator(ReifiedPropagatorArgs {
197                propagator: BinaryNotEqualsPropagatorArgs {
198                    a: terms[0].clone(),
199                    b: terms[1].scaled(-1).offset(self.rhs),
200                    constraint_tag: self.constraint_tag,
201                },
202                reification_literal,
203            })?;
204            Ok(())
205        } else {
206            LinearNotEqualPropagatorArgs {
207                terms: terms.into(),
208                rhs,
209                constraint_tag,
210            }
211            .implied_by(solver, reification_literal)
212        }
213    }
214}
215
216impl<Var> NegatableConstraint for NotEqualConstraint<Var>
217where
218    Var: IntegerVariable + Clone + 'static,
219{
220    type NegatedConstraint = EqualConstraint<Var>;
221
222    fn negation(&self) -> Self::NegatedConstraint {
223        EqualConstraint {
224            terms: self.terms.clone(),
225            rhs: self.rhs,
226            constraint_tag: self.constraint_tag,
227        }
228    }
229}