1use pumpkin_checking::AtomicConstraint;
2use pumpkin_checking::CheckerVariable;
3use pumpkin_checking::InferenceChecker;
4use pumpkin_checking::IntExt;
5use pumpkin_core::asserts::pumpkin_assert_simple;
6use pumpkin_core::conjunction;
7use pumpkin_core::declare_inference_label;
8use pumpkin_core::predicate;
9use pumpkin_core::proof::ConstraintTag;
10use pumpkin_core::proof::InferenceCode;
11use pumpkin_core::propagation::DomainEvents;
12use pumpkin_core::propagation::InferenceCheckers;
13use pumpkin_core::propagation::LocalId;
14use pumpkin_core::propagation::Priority;
15use pumpkin_core::propagation::PropagationContext;
16use pumpkin_core::propagation::Propagator;
17use pumpkin_core::propagation::PropagatorConstructor;
18use pumpkin_core::propagation::PropagatorConstructorContext;
19use pumpkin_core::propagation::ReadDomains;
20use pumpkin_core::results::PropagationStatusCP;
21use pumpkin_core::variables::IntegerVariable;
22
/// Arguments for constructing a [`DivisionPropagator`] enforcing the integer division
/// `rhs = numerator / denominator`.
#[derive(Clone, Debug)]
pub struct DivisionArgs<VA, VB, VC> {
    /// The dividend.
    pub numerator: VA,
    /// The divisor; its domain must not contain 0 when the propagator is created.
    pub denominator: VB,
    /// The quotient.
    pub rhs: VC,
    /// Tag of this constraint, combined with the `Division` label into the
    /// [`InferenceCode`] used for explanations and checkers.
    pub constraint_tag: ConstraintTag,
}
31
// Local identifiers under which the three variables are registered for domain events.
const ID_NUMERATOR: LocalId = LocalId::from(0);
const ID_DENOMINATOR: LocalId = LocalId::from(1);
const ID_RHS: LocalId = LocalId::from(2);

// Inference label shared by the propagator's inference code and its checker.
declare_inference_label!(Division);
37
38impl<VA, VB, VC> PropagatorConstructor for DivisionArgs<VA, VB, VC>
39where
40 VA: IntegerVariable + 'static,
41 VB: IntegerVariable + 'static,
42 VC: IntegerVariable + 'static,
43{
44 type PropagatorImpl = DivisionPropagator<VA, VB, VC>;
45
46 fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl {
47 let DivisionArgs {
48 numerator,
49 denominator,
50 rhs,
51 constraint_tag,
52 } = self;
53
54 pumpkin_assert_simple!(
55 !context.contains(&denominator, 0),
56 "Denominator cannot contain 0"
57 );
58
59 context.register(numerator.clone(), DomainEvents::BOUNDS, ID_NUMERATOR);
60 context.register(denominator.clone(), DomainEvents::BOUNDS, ID_DENOMINATOR);
61 context.register(rhs.clone(), DomainEvents::BOUNDS, ID_RHS);
62
63 let inference_code = InferenceCode::new(constraint_tag, Division);
64
65 DivisionPropagator {
66 numerator,
67 denominator,
68 rhs,
69 inference_code,
70 }
71 }
72
73 fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) {
74 checkers.add_inference_checker(
75 InferenceCode::new(self.constraint_tag, Division),
76 Box::new(IntegerDivisionChecker {
77 numerator: self.numerator.clone(),
78 denominator: self.denominator.clone(),
79 rhs: self.rhs.clone(),
80 }),
81 );
82 }
83}
84
/// Bounds propagator for the integer division `rhs = numerator / denominator`.
///
/// No propagation is performed while the denominator's bounds straddle zero.
#[derive(Clone, Debug)]
pub struct DivisionPropagator<VA, VB, VC> {
    numerator: VA,
    denominator: VB,
    rhs: VC,
    /// Code attached to every explanation this propagator posts.
    inference_code: InferenceCode,
}
98
99impl<VA: 'static, VB: 'static, VC: 'static> Propagator for DivisionPropagator<VA, VB, VC>
100where
101 VA: IntegerVariable,
102 VB: IntegerVariable,
103 VC: IntegerVariable,
104{
105 fn priority(&self) -> Priority {
106 Priority::High
107 }
108
109 fn name(&self) -> &str {
110 "Division"
111 }
112
113 fn propagate_from_scratch(&self, context: PropagationContext) -> PropagationStatusCP {
114 perform_propagation(
115 context,
116 &self.numerator,
117 &self.denominator,
118 &self.rhs,
119 &self.inference_code,
120 )
121 }
122}
123
124fn perform_propagation<VA: IntegerVariable, VB: IntegerVariable, VC: IntegerVariable>(
125 mut context: PropagationContext,
126 numerator: &VA,
127 denominator: &VB,
128 rhs: &VC,
129 inference_code: &InferenceCode,
130) -> PropagationStatusCP {
131 if context.lower_bound(denominator) < 0 && context.upper_bound(denominator) > 0 {
132 return Ok(());
136 }
137
138 let mut negated_numerator = &numerator.scaled(-1);
139 let mut numerator = &numerator.scaled(1);
140
141 let mut negated_denominator = &denominator.scaled(-1);
142 let mut denominator = &denominator.scaled(1);
143
144 if context.upper_bound(denominator) < 0 {
145 std::mem::swap(&mut numerator, &mut negated_numerator);
148 std::mem::swap(&mut denominator, &mut negated_denominator);
149 }
150
151 let negated_rhs = &rhs.scaled(-1);
152
153 propagate_signs(&mut context, numerator, denominator, rhs, inference_code)?;
156
157 if context.upper_bound(numerator) >= 0 && context.upper_bound(rhs) >= 0 {
160 propagate_upper_bounds(&mut context, numerator, denominator, rhs, inference_code)?;
161 }
162
163 if context.upper_bound(negated_numerator) >= 0 && context.upper_bound(negated_rhs) >= 0 {
166 propagate_upper_bounds(
167 &mut context,
168 negated_numerator,
169 denominator,
170 negated_rhs,
171 inference_code,
172 )?;
173 }
174
175 if context.lower_bound(numerator) >= 0 && context.lower_bound(rhs) >= 0 {
179 propagate_positive_domains(&mut context, numerator, denominator, rhs, inference_code)?;
180 }
181
182 if context.lower_bound(negated_numerator) >= 0 && context.lower_bound(negated_rhs) >= 0 {
186 propagate_positive_domains(
187 &mut context,
188 negated_numerator,
189 denominator,
190 negated_rhs,
191 inference_code,
192 )?;
193 }
194
195 Ok(())
196}
197
198fn propagate_positive_domains<VA: IntegerVariable, VB: IntegerVariable, VC: IntegerVariable>(
209 context: &mut PropagationContext,
210 numerator: &VA,
211 denominator: &VB,
212 rhs: &VC,
213 inference_code: &InferenceCode,
214) -> PropagationStatusCP {
215 let rhs_min = context.lower_bound(rhs);
216 let rhs_max = context.upper_bound(rhs);
217 let numerator_min = context.lower_bound(numerator);
218 let numerator_max = context.upper_bound(numerator);
219 let denominator_min = context.lower_bound(denominator);
220 let denominator_max = context.upper_bound(denominator);
221
222 let new_min_rhs = numerator_min / denominator_max;
224 if rhs_min < new_min_rhs {
225 context.post(
226 predicate![rhs >= new_min_rhs],
227 conjunction!(
228 [numerator >= numerator_min]
229 & [denominator <= denominator_max]
230 & [denominator >= 1]
231 ),
232 inference_code,
233 )?;
234 }
235
236 let new_min_numerator = denominator_min * rhs_min;
241 if numerator_min < new_min_numerator {
242 context.post(
243 predicate![numerator >= new_min_numerator],
244 conjunction!([denominator >= denominator_min] & [rhs >= rhs_min]),
245 inference_code,
246 )?;
247 }
248
249 if rhs_min > 0 {
254 let new_max_denominator = numerator_max / rhs_min;
255 if denominator_max > new_max_denominator {
256 context.post(
257 predicate![denominator <= new_max_denominator],
258 conjunction!(
259 [numerator <= numerator_max]
260 & [numerator >= 0]
261 & [rhs >= rhs_min]
262 & [denominator >= 1]
263 ),
264 inference_code,
265 )?;
266 }
267 }
268
269 let new_min_denominator = {
270 let dividend = numerator_min + 1;
272 let positive_divisor = rhs_max + 1;
273
274 let result = dividend / positive_divisor;
275 let adjust = result * positive_divisor < dividend;
276 result + adjust as i32
277 };
278
279 if denominator_min < new_min_denominator {
280 context.post(
281 predicate![denominator >= new_min_denominator],
282 conjunction!(
283 [numerator >= numerator_min] & [rhs <= rhs_max] & [rhs >= 0] & [denominator >= 1]
284 ),
285 inference_code,
286 )?;
287 }
288
289 Ok(())
290}
291
292fn propagate_upper_bounds<VA: IntegerVariable, VB: IntegerVariable, VC: IntegerVariable>(
299 context: &mut PropagationContext,
300 numerator: &VA,
301 denominator: &VB,
302 rhs: &VC,
303 inference_code: &InferenceCode,
304) -> PropagationStatusCP {
305 let rhs_max = context.upper_bound(rhs);
306 let numerator_max = context.upper_bound(numerator);
307 let denominator_min = context.lower_bound(denominator);
308 let denominator_max = context.upper_bound(denominator);
309
310 let new_max_rhs = numerator_max / denominator_min;
313 if rhs_max > new_max_rhs {
314 context.post(
315 predicate![rhs <= new_max_rhs],
316 conjunction!([numerator <= numerator_max] & [denominator >= denominator_min]),
317 inference_code,
318 )?;
319 }
320
321 let new_max_numerator = (rhs_max + 1) * denominator_max - 1;
327 if numerator_max > new_max_numerator {
328 context.post(
329 predicate![numerator <= new_max_numerator],
330 conjunction!([denominator <= denominator_max] & [denominator >= 1] & [rhs <= rhs_max]),
331 inference_code,
332 )?;
333 }
334
335 Ok(())
336}
337
338fn propagate_signs<VA: IntegerVariable, VB: IntegerVariable, VC: IntegerVariable>(
345 context: &mut PropagationContext,
346 numerator: &VA,
347 denominator: &VB,
348 rhs: &VC,
349 inference_code: &InferenceCode,
350) -> PropagationStatusCP {
351 let rhs_min = context.lower_bound(rhs);
352 let rhs_max = context.upper_bound(rhs);
353 let numerator_min = context.lower_bound(numerator);
354 let numerator_max = context.upper_bound(numerator);
355
356 if numerator_min >= 0 && rhs_min < 0 {
359 context.post(
360 predicate![rhs >= 0],
361 conjunction!([numerator >= 0] & [denominator >= 1]),
362 inference_code,
363 )?;
364 }
365
366 if numerator_min <= 0 && rhs_min > 0 {
368 context.post(
369 predicate![numerator >= 1],
370 conjunction!([rhs >= 1] & [denominator >= 1]),
371 inference_code,
372 )?;
373 }
374
375 if numerator_max <= 0 && rhs_max > 0 {
377 context.post(
378 predicate![rhs <= 0],
379 conjunction!([numerator <= 0] & [denominator >= 1]),
380 inference_code,
381 )?;
382 }
383
384 if numerator_max >= 0 && rhs_max < 0 {
386 context.post(
387 predicate![numerator <= -1],
388 conjunction!([rhs <= -1] & [denominator >= 1]),
389 inference_code,
390 )?;
391 }
392
393 Ok(())
394}
395
/// Inference checker for the integer division `rhs = numerator / denominator`.
///
/// `check` judges an inference solely from the bounds that the given variable state
/// induces on the three variables.
#[derive(Clone, Debug)]
pub struct IntegerDivisionChecker<VA, VB, VC> {
    pub numerator: VA,
    pub denominator: VB,
    pub rhs: VC,
}
402
403impl<VA, VB, VC, Atomic> InferenceChecker<Atomic> for IntegerDivisionChecker<VA, VB, VC>
404where
405 Atomic: AtomicConstraint,
406 VA: CheckerVariable<Atomic>,
407 VB: CheckerVariable<Atomic>,
408 VC: CheckerVariable<Atomic>,
409{
410 fn check(
411 &self,
412 state: pumpkin_checking::VariableState<Atomic>,
413 _premises: &[Atomic],
414 _consequent: Option<&Atomic>,
415 ) -> bool {
416 let x1 = self.numerator.induced_lower_bound(&state);
422 let x2 = self.numerator.induced_upper_bound(&state);
423 let y1 = self.denominator.induced_lower_bound(&state);
424 let y2 = self.denominator.induced_upper_bound(&state);
425
426 assert!(
427 y2 < 0 || y1 > 0,
428 "Currentl, the checker does not contain inferences where the denominator spans 0"
429 );
430
431 let computed_c_lower: IntExt = *[
432 x1.div_ceil(y1),
433 x1.div_ceil(y2),
434 x2.div_ceil(y1),
435 x2.div_ceil(y2),
436 ]
437 .iter()
438 .flatten()
439 .min()
440 .expect("Expected at least one element to be defined");
441
442 let computed_c_upper: IntExt = *[
443 x1.div_floor(y1),
444 x1.div_floor(y2),
445 x2.div_floor(y1),
446 x2.div_floor(y2),
447 ]
448 .iter()
449 .flatten()
450 .min()
451 .expect("Expected at least one element to be defined");
452
453 let c_lower = self.rhs.induced_lower_bound(&state);
454 let c_upper = self.rhs.induced_upper_bound(&state);
455
456 computed_c_upper < c_lower || computed_c_lower > c_upper
457 }
458}
459
#[allow(deprecated, reason = "Will be refactored")]
#[cfg(test)]
mod tests {
    use pumpkin_core::TestSolver;

    use super::*;

    /// `1 / 2` truncates to 0, so a right-hand side fixed to 2 must be reported as an
    /// error by `new_propagator`.
    #[test]
    fn detects_conflicts() {
        let mut solver = TestSolver::default();
        let (numerator, denominator, rhs) = (
            solver.new_variable(1, 1),
            solver.new_variable(2, 2),
            solver.new_variable(2, 2),
        );
        let constraint_tag = solver.new_constraint_tag();

        let outcome = solver.new_propagator(DivisionArgs {
            numerator,
            denominator,
            rhs,
            constraint_tag,
        });

        assert!(outcome.is_err());
    }
}