pumpkin_core/constraints/arithmetic/
equality.rs1use super::less_than_or_equals;
2use crate::constraints::Constraint;
3use crate::constraints::NegatableConstraint;
4use crate::proof::ConstraintTag;
5use crate::propagators::binary::BinaryEqualsPropagatorArgs;
6use crate::propagators::binary::BinaryNotEqualsPropagatorArgs;
7use crate::propagators::linear_not_equal::LinearNotEqualPropagatorArgs;
8use crate::propagators::ReifiedPropagatorArgs;
9use crate::variables::IntegerVariable;
10use crate::variables::Literal;
11use crate::variables::TransformableVariable;
12use crate::ConstraintOperationError;
13use crate::Solver;
14
15struct EqualConstraint<Var> {
16 terms: Box<[Var]>,
17 rhs: i32,
18 constraint_tag: ConstraintTag,
19}
20
21pub fn equals<Var: IntegerVariable + Clone + 'static>(
25 terms: impl Into<Box<[Var]>>,
26 rhs: i32,
27 constraint_tag: ConstraintTag,
28) -> impl NegatableConstraint {
29 EqualConstraint {
30 terms: terms.into(),
31 rhs,
32 constraint_tag,
33 }
34}
35
36pub fn binary_equals<Var: IntegerVariable + 'static>(
40 lhs: Var,
41 rhs: Var,
42 constraint_tag: ConstraintTag,
43) -> impl NegatableConstraint {
44 EqualConstraint {
45 terms: [lhs.scaled(1), rhs.scaled(-1)].into(),
46 rhs: 0,
47 constraint_tag,
48 }
49}
50
51struct NotEqualConstraint<Var> {
52 terms: Box<[Var]>,
53 rhs: i32,
54 constraint_tag: ConstraintTag,
55}
56
57pub fn not_equals<Var: IntegerVariable + Clone + 'static>(
61 terms: impl Into<Box<[Var]>>,
62 rhs: i32,
63 constraint_tag: ConstraintTag,
64) -> impl NegatableConstraint {
65 equals(terms, rhs, constraint_tag).negation()
66}
67
68pub fn binary_not_equals<Var: IntegerVariable + 'static>(
72 lhs: Var,
73 rhs: Var,
74 constraint_tag: ConstraintTag,
75) -> impl NegatableConstraint {
76 NotEqualConstraint {
77 terms: [lhs.scaled(1), rhs.scaled(-1)].into(),
78 rhs: 0,
79 constraint_tag,
80 }
81}
82
83impl<Var> Constraint for EqualConstraint<Var>
84where
85 Var: IntegerVariable + Clone + 'static,
86{
87 fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> {
88 if self.terms.len() == 2 && !solver.is_logging_full_proof() {
89 let _ = solver.add_propagator(BinaryEqualsPropagatorArgs {
90 a: self.terms[0].clone(),
91 b: self.terms[1].scaled(-1).offset(self.rhs),
92 constraint_tag: self.constraint_tag,
93 })?;
94 } else {
95 less_than_or_equals(self.terms.clone(), self.rhs, self.constraint_tag).post(solver)?;
96
97 let negated = self
98 .terms
99 .iter()
100 .map(|var| var.scaled(-1))
101 .collect::<Box<[_]>>();
102 less_than_or_equals(negated, -self.rhs, self.constraint_tag).post(solver)?;
103 }
104
105 Ok(())
106 }
107
108 fn implied_by(
109 self,
110 solver: &mut Solver,
111 reification_literal: Literal,
112 ) -> Result<(), ConstraintOperationError> {
113 if self.terms.len() == 2 {
114 let _ = solver.add_propagator(ReifiedPropagatorArgs {
115 propagator: BinaryEqualsPropagatorArgs {
116 a: self.terms[0].clone(),
117 b: self.terms[1].scaled(-1).offset(self.rhs),
118 constraint_tag: self.constraint_tag,
119 },
120 reification_literal,
121 })?;
122 } else {
123 less_than_or_equals(self.terms.clone(), self.rhs, self.constraint_tag)
124 .implied_by(solver, reification_literal)?;
125
126 let negated = self
127 .terms
128 .iter()
129 .map(|var| var.scaled(-1))
130 .collect::<Box<[_]>>();
131 less_than_or_equals(negated, -self.rhs, self.constraint_tag)
132 .implied_by(solver, reification_literal)?;
133 }
134
135 Ok(())
136 }
137}
138
139impl<Var> NegatableConstraint for EqualConstraint<Var>
140where
141 Var: IntegerVariable + Clone + 'static,
142{
143 type NegatedConstraint = NotEqualConstraint<Var>;
144
145 fn negation(&self) -> Self::NegatedConstraint {
146 NotEqualConstraint {
147 terms: self.terms.clone(),
148 rhs: self.rhs,
149 constraint_tag: self.constraint_tag,
150 }
151 }
152}
153
154impl<Var> Constraint for NotEqualConstraint<Var>
155where
156 Var: IntegerVariable + Clone + 'static,
157{
158 fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> {
159 let NotEqualConstraint {
160 terms,
161 rhs,
162 constraint_tag,
163 } = self;
164
165 if terms.len() == 2 {
166 let _ = solver.add_propagator(BinaryNotEqualsPropagatorArgs {
167 a: terms[0].clone(),
168 b: terms[1].scaled(-1).offset(self.rhs),
169 constraint_tag: self.constraint_tag,
170 })?;
171
172 Ok(())
173 } else {
174 LinearNotEqualPropagatorArgs {
175 terms: terms.into(),
176 rhs,
177 constraint_tag,
178 }
179 .post(solver)
180 }
181 }
182
183 fn implied_by(
184 self,
185 solver: &mut Solver,
186 reification_literal: Literal,
187 ) -> Result<(), ConstraintOperationError> {
188 let NotEqualConstraint {
189 terms,
190 rhs,
191 constraint_tag,
192 } = self;
193
194 if terms.len() == 2 {
195 let _ = solver.add_propagator(ReifiedPropagatorArgs {
196 propagator: BinaryNotEqualsPropagatorArgs {
197 a: terms[0].clone(),
198 b: terms[1].scaled(-1).offset(self.rhs),
199 constraint_tag: self.constraint_tag,
200 },
201 reification_literal,
202 })?;
203 Ok(())
204 } else {
205 LinearNotEqualPropagatorArgs {
206 terms: terms.into(),
207 rhs,
208 constraint_tag,
209 }
210 .implied_by(solver, reification_literal)
211 }
212 }
213}
214
215impl<Var> NegatableConstraint for NotEqualConstraint<Var>
216where
217 Var: IntegerVariable + Clone + 'static,
218{
219 type NegatedConstraint = EqualConstraint<Var>;
220
221 fn negation(&self) -> Self::NegatedConstraint {
222 EqualConstraint {
223 terms: self.terms.clone(),
224 rhs: self.rhs,
225 constraint_tag: self.constraint_tag,
226 }
227 }
228}