Skip to main content

pumpkin_propagators/propagators/arithmetic/
integer_division.rs

1use pumpkin_checking::AtomicConstraint;
2use pumpkin_checking::CheckerVariable;
3use pumpkin_checking::InferenceChecker;
4use pumpkin_checking::IntExt;
5use pumpkin_core::asserts::pumpkin_assert_simple;
6use pumpkin_core::conjunction;
7use pumpkin_core::declare_inference_label;
8use pumpkin_core::predicate;
9use pumpkin_core::proof::ConstraintTag;
10use pumpkin_core::proof::InferenceCode;
11use pumpkin_core::propagation::DomainEvents;
12use pumpkin_core::propagation::InferenceCheckers;
13use pumpkin_core::propagation::LocalId;
14use pumpkin_core::propagation::Priority;
15use pumpkin_core::propagation::PropagationContext;
16use pumpkin_core::propagation::Propagator;
17use pumpkin_core::propagation::PropagatorConstructor;
18use pumpkin_core::propagation::PropagatorConstructorContext;
19use pumpkin_core::propagation::ReadDomains;
20use pumpkin_core::results::PropagationStatusCP;
21use pumpkin_core::variables::IntegerVariable;
22
23/// The [`PropagatorConstructor`] for the [`DivisionPropagator`].
24#[derive(Clone, Debug)]
25pub struct DivisionArgs<VA, VB, VC> {
26    pub numerator: VA,
27    pub denominator: VB,
28    pub rhs: VC,
29    pub constraint_tag: ConstraintTag,
30}
31
// Local identifiers under which the three variables of the constraint are registered when the
// propagator is created.
/// The [`LocalId`] with which the numerator is registered.
const ID_NUMERATOR: LocalId = LocalId::from(0);
/// The [`LocalId`] with which the denominator is registered.
const ID_DENOMINATOR: LocalId = LocalId::from(1);
/// The [`LocalId`] with which the right-hand side is registered.
const ID_RHS: LocalId = LocalId::from(2);

// Declares the `Division` label which is combined with a constraint tag to form the
// `InferenceCode` of the division propagator.
declare_inference_label!(Division);
37
38impl<VA, VB, VC> PropagatorConstructor for DivisionArgs<VA, VB, VC>
39where
40    VA: IntegerVariable + 'static,
41    VB: IntegerVariable + 'static,
42    VC: IntegerVariable + 'static,
43{
44    type PropagatorImpl = DivisionPropagator<VA, VB, VC>;
45
46    fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl {
47        let DivisionArgs {
48            numerator,
49            denominator,
50            rhs,
51            constraint_tag,
52        } = self;
53
54        pumpkin_assert_simple!(
55            !context.contains(&denominator, 0),
56            "Denominator cannot contain 0"
57        );
58
59        context.register(numerator.clone(), DomainEvents::BOUNDS, ID_NUMERATOR);
60        context.register(denominator.clone(), DomainEvents::BOUNDS, ID_DENOMINATOR);
61        context.register(rhs.clone(), DomainEvents::BOUNDS, ID_RHS);
62
63        let inference_code = InferenceCode::new(constraint_tag, Division);
64
65        DivisionPropagator {
66            numerator,
67            denominator,
68            rhs,
69            inference_code,
70        }
71    }
72
73    fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) {
74        checkers.add_inference_checker(
75            InferenceCode::new(self.constraint_tag, Division),
76            Box::new(IntegerDivisionChecker {
77                numerator: self.numerator.clone(),
78                denominator: self.denominator.clone(),
79                rhs: self.rhs.clone(),
80            }),
81        );
82    }
83}
84
85/// A propagator for maintaining the constraint `numerator / denominator = rhs`; note that this
86/// propagator performs truncating division (i.e. rounding towards 0).
87///
88/// The propagator assumes that the `denominator` is a (non-zero) number.
89///
90/// The implementation is ported from [OR-tools](https://github.com/google/or-tools/blob/870edf6f7bff6b8ff0d267d936be7e331c5b8c2d/ortools/sat/integer_expr.cc#L1209C1-L1209C19).
91#[derive(Clone, Debug)]
92pub struct DivisionPropagator<VA, VB, VC> {
93    numerator: VA,
94    denominator: VB,
95    rhs: VC,
96    inference_code: InferenceCode,
97}
98
99impl<VA: 'static, VB: 'static, VC: 'static> Propagator for DivisionPropagator<VA, VB, VC>
100where
101    VA: IntegerVariable,
102    VB: IntegerVariable,
103    VC: IntegerVariable,
104{
105    fn priority(&self) -> Priority {
106        Priority::High
107    }
108
109    fn name(&self) -> &str {
110        "Division"
111    }
112
113    fn propagate_from_scratch(&self, context: PropagationContext) -> PropagationStatusCP {
114        perform_propagation(
115            context,
116            &self.numerator,
117            &self.denominator,
118            &self.rhs,
119            &self.inference_code,
120        )
121    }
122}
123
124fn perform_propagation<VA: IntegerVariable, VB: IntegerVariable, VC: IntegerVariable>(
125    mut context: PropagationContext,
126    numerator: &VA,
127    denominator: &VB,
128    rhs: &VC,
129    inference_code: &InferenceCode,
130) -> PropagationStatusCP {
131    if context.lower_bound(denominator) < 0 && context.upper_bound(denominator) > 0 {
132        // For now we don't do anything in this case, note that this will not lead to incorrect
133        // behaviour since any solution to this constraint will necessarily have to fix the
134        // denominator.
135        return Ok(());
136    }
137
138    let mut negated_numerator = &numerator.scaled(-1);
139    let mut numerator = &numerator.scaled(1);
140
141    let mut negated_denominator = &denominator.scaled(-1);
142    let mut denominator = &denominator.scaled(1);
143
144    if context.upper_bound(denominator) < 0 {
145        // If the denominator is negative then we swap the numerator with its negated version and we
146        // swap the denominator with its negated version.
147        std::mem::swap(&mut numerator, &mut negated_numerator);
148        std::mem::swap(&mut denominator, &mut negated_denominator);
149    }
150
151    let negated_rhs = &rhs.scaled(-1);
152
153    // We propagate the domains to their appropriate signs (e.g. if the numerator is negative and
154    // the denominator is positive then the rhs should also be negative)
155    propagate_signs(&mut context, numerator, denominator, rhs, inference_code)?;
156
157    // If the upper-bound of the numerator is positive and the upper-bound of the rhs is positive
158    // then we can simply update the upper-bounds
159    if context.upper_bound(numerator) >= 0 && context.upper_bound(rhs) >= 0 {
160        propagate_upper_bounds(&mut context, numerator, denominator, rhs, inference_code)?;
161    }
162
163    // If the lower-bound of the numerator is negative and the lower-bound of the rhs is negative
164    // then we negate these variables and update the upper-bounds
165    if context.upper_bound(negated_numerator) >= 0 && context.upper_bound(negated_rhs) >= 0 {
166        propagate_upper_bounds(
167            &mut context,
168            negated_numerator,
169            denominator,
170            negated_rhs,
171            inference_code,
172        )?;
173    }
174
175    // If the domain of the numerator is positive and the domain of the rhs is positive (and we know
176    // that our denominator is positive) then we can propagate based on the assumption that all the
177    // domains are positive
178    if context.lower_bound(numerator) >= 0 && context.lower_bound(rhs) >= 0 {
179        propagate_positive_domains(&mut context, numerator, denominator, rhs, inference_code)?;
180    }
181
182    // If the domain of the numerator is negative and the domain of the rhs is negative (and we know
183    // that our denominator is positive) then we propagate based on the views over the numerator and
184    // rhs
185    if context.lower_bound(negated_numerator) >= 0 && context.lower_bound(negated_rhs) >= 0 {
186        propagate_positive_domains(
187            &mut context,
188            negated_numerator,
189            denominator,
190            negated_rhs,
191            inference_code,
192        )?;
193    }
194
195    Ok(())
196}
197
198/// Propagates the domains of variables if all the domains are positive (if the variables are
199/// sign-fixed then we simply transform them to positive domains using [`AffineView`]s); it performs
200/// the following propagations:
201/// - The minimum value that division can take on is the smallest value that `numerator /
202///   denominator` can take on
203/// - The numerator is at least as large as the smallest value that `denominator * rhs` can take on
204/// - The value of the denominator is smaller than the largest value that `numerator / rhs` can take
205///   on
206/// - The denominator is at least as large as the ratio between the largest ceiled ratio between
207///   `numerator + 1` and `rhs + 1`
208fn propagate_positive_domains<VA: IntegerVariable, VB: IntegerVariable, VC: IntegerVariable>(
209    context: &mut PropagationContext,
210    numerator: &VA,
211    denominator: &VB,
212    rhs: &VC,
213    inference_code: &InferenceCode,
214) -> PropagationStatusCP {
215    let rhs_min = context.lower_bound(rhs);
216    let rhs_max = context.upper_bound(rhs);
217    let numerator_min = context.lower_bound(numerator);
218    let numerator_max = context.upper_bound(numerator);
219    let denominator_min = context.lower_bound(denominator);
220    let denominator_max = context.upper_bound(denominator);
221
222    // The new minimum value of the rhs is the minimum value that the division can take on
223    let new_min_rhs = numerator_min / denominator_max;
224    if rhs_min < new_min_rhs {
225        context.post(
226            predicate![rhs >= new_min_rhs],
227            conjunction!(
228                [numerator >= numerator_min]
229                    & [denominator <= denominator_max]
230                    & [denominator >= 1]
231            ),
232            inference_code,
233        )?;
234    }
235
236    // numerator / denominator >= rhs_min
237    // numerator >= rhs_min * denominator
238    // numerator >= rhs_min * denominator_min
239    // Note that we use rhs_min rather than new_min_rhs, this appears to be a heuristic
240    let new_min_numerator = denominator_min * rhs_min;
241    if numerator_min < new_min_numerator {
242        context.post(
243            predicate![numerator >= new_min_numerator],
244            conjunction!([denominator >= denominator_min] & [rhs >= rhs_min]),
245            inference_code,
246        )?;
247    }
248
249    // numerator / denominator >= rhs_min
250    // numerator >= rhs_min * denominator
251    // If rhs_min == 0 -> no propagations
252    // Otherwise, denominator <= numerator / rhs_min & denominator <= numerator_max / rhs_min
253    if rhs_min > 0 {
254        let new_max_denominator = numerator_max / rhs_min;
255        if denominator_max > new_max_denominator {
256            context.post(
257                predicate![denominator <= new_max_denominator],
258                conjunction!(
259                    [numerator <= numerator_max]
260                        & [numerator >= 0]
261                        & [rhs >= rhs_min]
262                        & [denominator >= 1]
263                ),
264                inference_code,
265            )?;
266        }
267    }
268
269    let new_min_denominator = {
270        // Called the CeilRatio in OR-tools
271        let dividend = numerator_min + 1;
272        let positive_divisor = rhs_max + 1;
273
274        let result = dividend / positive_divisor;
275        let adjust = result * positive_divisor < dividend;
276        result + adjust as i32
277    };
278
279    if denominator_min < new_min_denominator {
280        context.post(
281            predicate![denominator >= new_min_denominator],
282            conjunction!(
283                [numerator >= numerator_min] & [rhs <= rhs_max] & [rhs >= 0] & [denominator >= 1]
284            ),
285            inference_code,
286        )?;
287    }
288
289    Ok(())
290}
291
292/// Propagates the upper-bounds of the right-hand side and the numerator, it performs the following
293/// propagations
294/// - The maximum value of the right-hand side can only be as large as the largest value that
295///   `numerator / denominator` can take on
296/// - The maximum value of the numerator is smaller than `(ub(rhs) + 1) * denominator - 1`, note
297///   that this might not be the most constrictive bound
298fn propagate_upper_bounds<VA: IntegerVariable, VB: IntegerVariable, VC: IntegerVariable>(
299    context: &mut PropagationContext,
300    numerator: &VA,
301    denominator: &VB,
302    rhs: &VC,
303    inference_code: &InferenceCode,
304) -> PropagationStatusCP {
305    let rhs_max = context.upper_bound(rhs);
306    let numerator_max = context.upper_bound(numerator);
307    let denominator_min = context.lower_bound(denominator);
308    let denominator_max = context.upper_bound(denominator);
309
310    // The new maximum value of the rhs is the maximum value that the division can take on (note
311    // that numerator_max is positive and denominator_min is also positive)
312    let new_max_rhs = numerator_max / denominator_min;
313    if rhs_max > new_max_rhs {
314        context.post(
315            predicate![rhs <= new_max_rhs],
316            conjunction!([numerator <= numerator_max] & [denominator >= denominator_min]),
317            inference_code,
318        )?;
319    }
320
321    // numerator / denominator <= rhs.max
322    // numerator < (rhs.max + 1) * denominator
323    // numerator + 1 <= (rhs.max + 1) * denominator.max
324    // numerator <= (rhs.max + 1) * denominator.max - 1
325    // Note that we use rhs_max here rather than the new upper-bound, this appears to be a heuristic
326    let new_max_numerator = (rhs_max + 1) * denominator_max - 1;
327    if numerator_max > new_max_numerator {
328        context.post(
329            predicate![numerator <= new_max_numerator],
330            conjunction!([denominator <= denominator_max] & [denominator >= 1] & [rhs <= rhs_max]),
331            inference_code,
332        )?;
333    }
334
335    Ok(())
336}
337
338/// Propagates the signs of the variables, more specifically, it performs the following propagations
339/// (assuming that the denominator is always > 0):
340/// - If the numerator is non-negative then the right-hand side must be non-negative as well
341/// - If the right-hand side is positive then the numerator must be positive as well
342/// - If the numerator is non-positive then the right-hand side must be non-positive as well
343/// - If the right-hand is negative then the numerator must be negative as well
344fn propagate_signs<VA: IntegerVariable, VB: IntegerVariable, VC: IntegerVariable>(
345    context: &mut PropagationContext,
346    numerator: &VA,
347    denominator: &VB,
348    rhs: &VC,
349    inference_code: &InferenceCode,
350) -> PropagationStatusCP {
351    let rhs_min = context.lower_bound(rhs);
352    let rhs_max = context.upper_bound(rhs);
353    let numerator_min = context.lower_bound(numerator);
354    let numerator_max = context.upper_bound(numerator);
355
356    // First we propagate the signs
357    // If the numerator >= 0 (and we know that denominator > 0) then the rhs must be >= 0
358    if numerator_min >= 0 && rhs_min < 0 {
359        context.post(
360            predicate![rhs >= 0],
361            conjunction!([numerator >= 0] & [denominator >= 1]),
362            inference_code,
363        )?;
364    }
365
366    // If rhs > 0 (and we know that denominator > 0) then the numerator must be > 0
367    if numerator_min <= 0 && rhs_min > 0 {
368        context.post(
369            predicate![numerator >= 1],
370            conjunction!([rhs >= 1] & [denominator >= 1]),
371            inference_code,
372        )?;
373    }
374
375    // If numerator <= 0 (and we know that denominator > 0) then the rhs must be <= 0
376    if numerator_max <= 0 && rhs_max > 0 {
377        context.post(
378            predicate![rhs <= 0],
379            conjunction!([numerator <= 0] & [denominator >= 1]),
380            inference_code,
381        )?;
382    }
383
384    // If the rhs < 0 (and we know that denominator > 0) then the numerator must be < 0
385    if numerator_max >= 0 && rhs_max < 0 {
386        context.post(
387            predicate![numerator <= -1],
388            conjunction!([rhs <= -1] & [denominator >= 1]),
389            inference_code,
390        )?;
391    }
392
393    Ok(())
394}
395
/// Checker for inferences of the constraint `numerator / denominator = rhs`.
///
/// The check compares an interval computed for `numerator / denominator` with the domain of
/// `rhs`.
#[derive(Debug, Clone)]
pub struct IntegerDivisionChecker<VA, VB, VC> {
    /// The variable which is divided.
    pub numerator: VA,
    /// The variable by which the numerator is divided.
    pub denominator: VB,
    /// The variable which should equal the result of the division.
    pub rhs: VC,
}
402
403impl<VA, VB, VC, Atomic> InferenceChecker<Atomic> for IntegerDivisionChecker<VA, VB, VC>
404where
405    Atomic: AtomicConstraint,
406    VA: CheckerVariable<Atomic>,
407    VB: CheckerVariable<Atomic>,
408    VC: CheckerVariable<Atomic>,
409{
410    fn check(
411        &self,
412        state: pumpkin_checking::VariableState<Atomic>,
413        _premises: &[Atomic],
414        _consequent: Option<&Atomic>,
415    ) -> bool {
416        // We apply interval arithmetic to determine that the computed interval `a div b`
417        // does not intersect with the domain of `c`.
418        //
419        // See https://en.wikipedia.org/wiki/Interval_arithmetic#Interval_operators.
420
421        let x1 = self.numerator.induced_lower_bound(&state);
422        let x2 = self.numerator.induced_upper_bound(&state);
423        let y1 = self.denominator.induced_lower_bound(&state);
424        let y2 = self.denominator.induced_upper_bound(&state);
425
426        assert!(
427            y2 < 0 || y1 > 0,
428            "Currentl, the checker does not contain inferences where the denominator spans 0"
429        );
430
431        let computed_c_lower: IntExt = *[
432            x1.div_ceil(y1),
433            x1.div_ceil(y2),
434            x2.div_ceil(y1),
435            x2.div_ceil(y2),
436        ]
437        .iter()
438        .flatten()
439        .min()
440        .expect("Expected at least one element to be defined");
441
442        let computed_c_upper: IntExt = *[
443            x1.div_floor(y1),
444            x1.div_floor(y2),
445            x2.div_floor(y1),
446            x2.div_floor(y2),
447        ]
448        .iter()
449        .flatten()
450        .min()
451        .expect("Expected at least one element to be defined");
452
453        let c_lower = self.rhs.induced_lower_bound(&state);
454        let c_upper = self.rhs.induced_upper_bound(&state);
455
456        computed_c_upper < c_lower || computed_c_lower > c_upper
457    }
458}
459
#[allow(deprecated, reason = "Will be refactored")]
#[cfg(test)]
mod tests {
    use pumpkin_core::TestSolver;

    use super::*;

    /// Truncating `1 / 2` gives 0, so creating the propagator with the right-hand side fixed
    /// to 2 is expected to result in an error.
    #[test]
    fn detects_conflicts() {
        let mut solver = TestSolver::default();
        let one = solver.new_variable(1, 1);
        let two = solver.new_variable(2, 2);
        let fixed_rhs = solver.new_variable(2, 2);
        let tag = solver.new_constraint_tag();

        let args = DivisionArgs {
            numerator: one,
            denominator: two,
            rhs: fixed_rhs,
            constraint_tag: tag,
        };
        let creation_result = solver.new_propagator(args);

        assert!(creation_result.is_err());
    }
}