pumpkin_constraints/constraints/arithmetic/
equality.rs1use pumpkin_core::ConstraintOperationError;
2use pumpkin_core::Solver;
3use pumpkin_core::constraints::Constraint;
4use pumpkin_core::constraints::NegatableConstraint;
5use pumpkin_core::options::ReifiedPropagatorArgs;
6use pumpkin_core::proof::ConstraintTag;
7use pumpkin_core::variables::IntegerVariable;
8use pumpkin_core::variables::Literal;
9use pumpkin_core::variables::TransformableVariable;
10use pumpkin_propagators::arithmetic::BinaryEqualsPropagatorArgs;
11use pumpkin_propagators::arithmetic::BinaryNotEqualsPropagatorArgs;
12use pumpkin_propagators::arithmetic::LinearNotEqualPropagatorArgs;
13
14use super::less_than_or_equals;
15
/// The linear equality `terms[0] + ... + terms[n - 1] = rhs`.
///
/// Constructed by [`equals`] and [`binary_equals`]; its negation is a [`NotEqualConstraint`]
/// over the same terms and right-hand side.
struct EqualConstraint<Var> {
    /// The (possibly scaled) variables whose sum is constrained.
    terms: Box<[Var]>,
    /// The constant the sum of `terms` must equal.
    rhs: i32,
    /// Tag handed to every propagator this constraint posts.
    constraint_tag: ConstraintTag,
}
21
22pub fn equals<Var: IntegerVariable + Clone + 'static>(
26 terms: impl Into<Box<[Var]>>,
27 rhs: i32,
28 constraint_tag: ConstraintTag,
29) -> impl NegatableConstraint {
30 EqualConstraint {
31 terms: terms.into(),
32 rhs,
33 constraint_tag,
34 }
35}
36
37pub fn binary_equals<Var: IntegerVariable + 'static>(
41 lhs: Var,
42 rhs: Var,
43 constraint_tag: ConstraintTag,
44) -> impl NegatableConstraint {
45 EqualConstraint {
46 terms: [lhs.scaled(1), rhs.scaled(-1)].into(),
47 rhs: 0,
48 constraint_tag,
49 }
50}
51
/// The linear disequality `terms[0] + ... + terms[n - 1] != rhs`.
///
/// Constructed by [`not_equals`] and [`binary_not_equals`]; its negation is an
/// [`EqualConstraint`] over the same terms and right-hand side.
struct NotEqualConstraint<Var> {
    /// The (possibly scaled) variables whose sum is constrained.
    terms: Box<[Var]>,
    /// The constant the sum of `terms` must differ from.
    rhs: i32,
    /// Tag handed to every propagator this constraint posts.
    constraint_tag: ConstraintTag,
}
57
58pub fn not_equals<Var: IntegerVariable + Clone + 'static>(
62 terms: impl Into<Box<[Var]>>,
63 rhs: i32,
64 constraint_tag: ConstraintTag,
65) -> impl NegatableConstraint {
66 equals(terms, rhs, constraint_tag).negation()
67}
68
69pub fn binary_not_equals<Var: IntegerVariable + 'static>(
73 lhs: Var,
74 rhs: Var,
75 constraint_tag: ConstraintTag,
76) -> impl NegatableConstraint {
77 NotEqualConstraint {
78 terms: [lhs.scaled(1), rhs.scaled(-1)].into(),
79 rhs: 0,
80 constraint_tag,
81 }
82}
83
84impl<Var> Constraint for EqualConstraint<Var>
85where
86 Var: IntegerVariable + Clone + 'static,
87{
88 fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> {
89 if self.terms.len() == 2 && !solver.is_logging_proof() {
90 let _ = solver.add_propagator(BinaryEqualsPropagatorArgs {
91 a: self.terms[0].clone(),
92 b: self.terms[1].scaled(-1).offset(self.rhs),
93 constraint_tag: self.constraint_tag,
94 })?;
95 } else {
96 less_than_or_equals(self.terms.clone(), self.rhs, self.constraint_tag).post(solver)?;
97
98 let negated = self
99 .terms
100 .iter()
101 .map(|var| var.scaled(-1))
102 .collect::<Box<[_]>>();
103 less_than_or_equals(negated, -self.rhs, self.constraint_tag).post(solver)?;
104 }
105
106 Ok(())
107 }
108
109 fn implied_by(
110 self,
111 solver: &mut Solver,
112 reification_literal: Literal,
113 ) -> Result<(), ConstraintOperationError> {
114 if self.terms.len() == 2 && !solver.is_logging_proof() {
115 let _ = solver.add_propagator(ReifiedPropagatorArgs {
116 propagator: BinaryEqualsPropagatorArgs {
117 a: self.terms[0].clone(),
118 b: self.terms[1].scaled(-1).offset(self.rhs),
119 constraint_tag: self.constraint_tag,
120 },
121 reification_literal,
122 })?;
123 } else {
124 less_than_or_equals(self.terms.clone(), self.rhs, self.constraint_tag)
125 .implied_by(solver, reification_literal)?;
126
127 let negated = self
128 .terms
129 .iter()
130 .map(|var| var.scaled(-1))
131 .collect::<Box<[_]>>();
132 less_than_or_equals(negated, -self.rhs, self.constraint_tag)
133 .implied_by(solver, reification_literal)?;
134 }
135
136 Ok(())
137 }
138}
139
140impl<Var> NegatableConstraint for EqualConstraint<Var>
141where
142 Var: IntegerVariable + Clone + 'static,
143{
144 type NegatedConstraint = NotEqualConstraint<Var>;
145
146 fn negation(&self) -> Self::NegatedConstraint {
147 NotEqualConstraint {
148 terms: self.terms.clone(),
149 rhs: self.rhs,
150 constraint_tag: self.constraint_tag,
151 }
152 }
153}
154
155impl<Var> Constraint for NotEqualConstraint<Var>
156where
157 Var: IntegerVariable + Clone + 'static,
158{
159 fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> {
160 let NotEqualConstraint {
161 terms,
162 rhs,
163 constraint_tag,
164 } = self;
165
166 if terms.len() == 2 {
167 let _ = solver.add_propagator(BinaryNotEqualsPropagatorArgs {
168 a: terms[0].clone(),
169 b: terms[1].scaled(-1).offset(self.rhs),
170 constraint_tag: self.constraint_tag,
171 })?;
172
173 Ok(())
174 } else {
175 LinearNotEqualPropagatorArgs {
176 terms: terms.into(),
177 rhs,
178 constraint_tag,
179 }
180 .post(solver)
181 }
182 }
183
184 fn implied_by(
185 self,
186 solver: &mut Solver,
187 reification_literal: Literal,
188 ) -> Result<(), ConstraintOperationError> {
189 let NotEqualConstraint {
190 terms,
191 rhs,
192 constraint_tag,
193 } = self;
194
195 if terms.len() == 2 {
196 let _ = solver.add_propagator(ReifiedPropagatorArgs {
197 propagator: BinaryNotEqualsPropagatorArgs {
198 a: terms[0].clone(),
199 b: terms[1].scaled(-1).offset(self.rhs),
200 constraint_tag: self.constraint_tag,
201 },
202 reification_literal,
203 })?;
204 Ok(())
205 } else {
206 LinearNotEqualPropagatorArgs {
207 terms: terms.into(),
208 rhs,
209 constraint_tag,
210 }
211 .implied_by(solver, reification_literal)
212 }
213 }
214}
215
216impl<Var> NegatableConstraint for NotEqualConstraint<Var>
217where
218 Var: IntegerVariable + Clone + 'static,
219{
220 type NegatedConstraint = EqualConstraint<Var>;
221
222 fn negation(&self) -> Self::NegatedConstraint {
223 EqualConstraint {
224 terms: self.terms.clone(),
225 rhs: self.rhs,
226 constraint_tag: self.constraint_tag,
227 }
228 }
229}