Skip to main content

pumpkin_propagators/propagators/
element.rs

1//! Contains the propagator for the [Element](https://sofdem.github.io/gccat/gccat/Celement.html)
2//! constraint.
3#![allow(clippy::double_parens, reason = "originates inside the bitfield macro")]
4
5use std::cell::RefCell;
6
7use bitfield_struct::bitfield;
8use pumpkin_checking::AtomicConstraint;
9use pumpkin_checking::CheckerVariable;
10use pumpkin_checking::Domain;
11use pumpkin_checking::InferenceChecker;
12use pumpkin_checking::Union;
13use pumpkin_core::conjunction;
14use pumpkin_core::declare_inference_label;
15use pumpkin_core::predicate;
16use pumpkin_core::predicates::Predicate;
17use pumpkin_core::proof::ConstraintTag;
18use pumpkin_core::proof::InferenceCode;
19use pumpkin_core::propagation::DomainEvents;
20use pumpkin_core::propagation::ExplanationContext;
21use pumpkin_core::propagation::InferenceCheckers;
22use pumpkin_core::propagation::LocalId;
23use pumpkin_core::propagation::Priority;
24use pumpkin_core::propagation::PropagationContext;
25use pumpkin_core::propagation::Propagator;
26use pumpkin_core::propagation::PropagatorConstructor;
27use pumpkin_core::propagation::PropagatorConstructorContext;
28use pumpkin_core::propagation::ReadDomains;
29use pumpkin_core::results::PropagationStatusCP;
30use pumpkin_core::variables::IntegerVariable;
31use pumpkin_core::variables::Reason;
32
33#[derive(Clone, Debug)]
34pub struct ElementArgs<VX, VI, VE> {
35    pub array: Box<[VX]>,
36    pub index: VI,
37    pub rhs: VE,
38    pub constraint_tag: ConstraintTag,
39}
40
41declare_inference_label!(Element);
42
43impl<VX, VI, VE> PropagatorConstructor for ElementArgs<VX, VI, VE>
44where
45    VX: IntegerVariable + 'static,
46    VI: IntegerVariable + 'static,
47    VE: IntegerVariable + 'static,
48{
49    type PropagatorImpl = ElementPropagator<VX, VI, VE>;
50
51    fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) {
52        checkers.add_inference_checker(
53            InferenceCode::new(self.constraint_tag, Element),
54            Box::new(ElementChecker::new(
55                self.array.clone(),
56                self.index.clone(),
57                self.rhs.clone(),
58            )),
59        );
60    }
61
62    fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl {
63        let ElementArgs {
64            array,
65            index,
66            rhs,
67            constraint_tag,
68        } = self;
69
70        for (i, x_i) in array.iter().enumerate() {
71            context.register(
72                x_i.clone(),
73                DomainEvents::ANY_INT,
74                LocalId::from(i as u32 + ID_X_OFFSET),
75            );
76        }
77
78        context.register(index.clone(), DomainEvents::ANY_INT, ID_INDEX);
79        context.register(rhs.clone(), DomainEvents::ANY_INT, ID_RHS);
80
81        let inference_code = InferenceCode::new(constraint_tag, Element);
82
83        ElementPropagator {
84            array,
85            index,
86            rhs,
87            inference_code,
88            rhs_reason_buffer: vec![],
89        }
90    }
91}
92
93const ID_INDEX: LocalId = LocalId::from(0);
94const ID_RHS: LocalId = LocalId::from(1);
95
96// local ids of array vars are shifted by ID_X_OFFSET
97const ID_X_OFFSET: u32 = 2;
98
99/// Arc-consistent propagator for constraint `element([x_1, \ldots, x_n], i, e)`, where `x_j` are
100///  variables, `i` is an integer variable, and `e` is a variable, which holds iff `x_i = e`
101///
102/// Note that this propagator is 0-indexed
103#[derive(Clone, Debug)]
104pub struct ElementPropagator<VX, VI, VE> {
105    array: Box<[VX]>,
106    index: VI,
107    rhs: VE,
108    inference_code: InferenceCode,
109
110    rhs_reason_buffer: Vec<Predicate>,
111}
112
113impl<VX, VI, VE> Propagator for ElementPropagator<VX, VI, VE>
114where
115    VX: IntegerVariable + 'static,
116    VI: IntegerVariable + 'static,
117    VE: IntegerVariable + 'static,
118{
119    fn priority(&self) -> Priority {
120        Priority::Low
121    }
122
123    fn name(&self) -> &str {
124        "Element"
125    }
126
127    fn propagate_from_scratch(&self, mut context: PropagationContext) -> PropagationStatusCP {
128        self.propagate_index_bounds_within_array(&mut context)?;
129
130        self.propagate_rhs_bounds_based_on_array(&mut context)?;
131
132        self.propagate_index_based_on_domain_intersection_with_rhs(&mut context)?;
133
134        if context.is_fixed(&self.index) {
135            let idx = context.lower_bound(&self.index);
136            self.propagate_equality(&mut context, idx)?;
137        }
138
139        Ok(())
140    }
141
142    fn lazy_explanation(&mut self, code: u64, context: ExplanationContext) -> &[Predicate] {
143        let payload = RightHandSideReason::from_bits(code);
144
145        self.rhs_reason_buffer.clear();
146        self.rhs_reason_buffer
147            .extend(self.array.iter().enumerate().map(|(idx, variable)| {
148                if context.contains_at_trail_position(
149                    &self.index,
150                    idx as i32,
151                    context.get_trail_position(),
152                ) {
153                    match payload.bound() {
154                        Bound::Lower => predicate![variable >= payload.value()],
155                        Bound::Upper => predicate![variable <= payload.value()],
156                    }
157                } else {
158                    predicate![self.index != idx as i32]
159                }
160            }));
161
162        &self.rhs_reason_buffer
163    }
164}
165
166impl<VX, VI, VE> ElementPropagator<VX, VI, VE>
167where
168    VX: IntegerVariable + 'static,
169    VI: IntegerVariable + 'static,
170    VE: IntegerVariable + 'static,
171{
172    /// Propagate the bounds of `self.index` to be in the range `[0, self.array.len())`.
173    fn propagate_index_bounds_within_array(
174        &self,
175        context: &mut PropagationContext<'_>,
176    ) -> PropagationStatusCP {
177        context.post(
178            predicate![self.index >= 0],
179            conjunction!(),
180            &self.inference_code,
181        )?;
182        context.post(
183            predicate![self.index <= self.array.len() as i32 - 1],
184            conjunction!(),
185            &self.inference_code,
186        )?;
187        Ok(())
188    }
189
190    /// The lower bound (resp. upper bound) of the right-hand side can be the minimum lower
191    /// bound (res. maximum upper bound) of the elements.
192    fn propagate_rhs_bounds_based_on_array(
193        &self,
194        context: &mut PropagationContext<'_>,
195    ) -> PropagationStatusCP {
196        let (rhs_lb, rhs_ub) = self
197            .array
198            .iter()
199            .enumerate()
200            .filter(|(idx, _)| context.contains(&self.index, *idx as i32))
201            .fold((i32::MAX, i32::MIN), |(rhs_lb, rhs_ub), (_, element)| {
202                (
203                    i32::min(rhs_lb, context.lower_bound(element)),
204                    i32::max(rhs_ub, context.upper_bound(element)),
205                )
206            });
207
208        context.post(
209            predicate![self.rhs >= rhs_lb],
210            Reason::DynamicLazy(
211                RightHandSideReason::new()
212                    .with_bound(Bound::Lower)
213                    .with_value(rhs_lb)
214                    .into_bits(),
215            ),
216            &self.inference_code,
217        )?;
218        context.post(
219            predicate![self.rhs <= rhs_ub],
220            Reason::DynamicLazy(
221                RightHandSideReason::new()
222                    .with_bound(Bound::Upper)
223                    .with_value(rhs_ub)
224                    .into_bits(),
225            ),
226            &self.inference_code,
227        )?;
228
229        Ok(())
230    }
231
232    /// Go through the array. For every element for which the domain does not intersect with the
233    /// right-hand side, remove it from index.
234    fn propagate_index_based_on_domain_intersection_with_rhs(
235        &self,
236        context: &mut PropagationContext<'_>,
237    ) -> PropagationStatusCP {
238        let rhs_lb = context.lower_bound(&self.rhs);
239        let rhs_ub = context.upper_bound(&self.rhs);
240        let mut to_remove = vec![];
241        for idx in context.iterate_domain(&self.index) {
242            let element = &self.array[idx as usize];
243
244            let element_ub = context.upper_bound(element);
245            let element_lb = context.lower_bound(element);
246
247            let reason = if rhs_lb > element_ub {
248                conjunction!([element <= rhs_lb - 1] & [self.rhs >= rhs_lb])
249            } else if rhs_ub < element_lb {
250                conjunction!([element >= rhs_ub + 1] & [self.rhs <= rhs_ub])
251            } else {
252                continue;
253            };
254
255            to_remove.push((idx, reason));
256        }
257
258        for (idx, reason) in to_remove.drain(..) {
259            context.post(predicate![self.index != idx], reason, &self.inference_code)?;
260        }
261
262        Ok(())
263    }
264
265    /// Propagate equality between lhs and rhs. This assumes the bounds of rhs have already been
266    /// tightened to the bounds of lhs, through a previous propagation rule.
267    fn propagate_equality(
268        &self,
269        context: &mut PropagationContext<'_>,
270        index: i32,
271    ) -> PropagationStatusCP {
272        let rhs_lb = context.lower_bound(&self.rhs);
273        let rhs_ub = context.upper_bound(&self.rhs);
274        let lhs = &self.array[index as usize];
275
276        context.post(
277            predicate![lhs >= rhs_lb],
278            conjunction!([self.rhs >= rhs_lb] & [self.index == index]),
279            &self.inference_code,
280        )?;
281        context.post(
282            predicate![lhs <= rhs_ub],
283            conjunction!([self.rhs <= rhs_ub] & [self.index == index]),
284            &self.inference_code,
285        )?;
286        Ok(())
287    }
288}
289
290#[derive(Clone, Copy, Debug, PartialEq, Eq)]
291#[repr(u8)]
292enum Bound {
293    Lower = 0,
294    Upper = 1,
295}
296
297impl Bound {
298    const fn into_bits(self) -> u8 {
299        self as _
300    }
301
302    const fn from_bits(value: u8) -> Self {
303        match value {
304            0 => Bound::Lower,
305            _ => Bound::Upper,
306        }
307    }
308}
309
310#[bitfield(u64)]
311struct RightHandSideReason {
312    #[bits(32, from = Bound::from_bits)]
313    bound: Bound,
314    value: i32,
315}
316
317#[derive(Clone, Debug)]
318pub struct ElementChecker<VX, VI, VE> {
319    array: Box<[VX]>,
320    index: VI,
321    rhs: VE,
322
323    union: RefCell<Union>,
324}
325
326impl<VX, VI, VE> ElementChecker<VX, VI, VE> {
327    /// Create a new [`ElementChecker`].
328    pub fn new(array: Box<[VX]>, index: VI, rhs: VE) -> Self {
329        ElementChecker {
330            array,
331            index,
332            rhs,
333            union: RefCell::new(Union::empty()),
334        }
335    }
336}
337
338impl<VX, VI, VE, Atomic> InferenceChecker<Atomic> for ElementChecker<VX, VI, VE>
339where
340    Atomic: AtomicConstraint,
341    VX: CheckerVariable<Atomic>,
342    VI: CheckerVariable<Atomic>,
343    VE: CheckerVariable<Atomic>,
344{
345    fn check(
346        &self,
347        state: pumpkin_checking::VariableState<Atomic>,
348        _: &[Atomic],
349        _: Option<&Atomic>,
350    ) -> bool {
351        self.union.borrow_mut().reset();
352
353        // A domain consistent checker for element does the following:
354        // 1. Determine the elements in the array whose index is in the domain of the index
355        //    variable.
356        // 2. Take the union of the domains of those elements.
357        // 3. Intersect that union with the domain on the right-hand side.
358        //
359        // The intersection should be empty for a conflict to exist.
360        let supported_elements: Vec<_> = self
361            .array
362            .iter()
363            .enumerate()
364            .filter(|(idx, _)| self.index.induced_domain_contains(&state, *idx as i32))
365            .map(|(_, element)| element)
366            .collect();
367
368        for element in supported_elements {
369            self.union.borrow_mut().add(&state, element);
370        }
371
372        assert!(
373            self.union.borrow().is_consistent(),
374            "at least one element has a non-empty domain or else variable state would be inconsistent"
375        );
376
377        // Compute `|union cap rhs| == 0`.
378        let intersection_lower_bound = self
379            .union
380            .borrow()
381            .lower_bound()
382            .max(self.rhs.induced_lower_bound(&state));
383        let intersection_upper_bound = self
384            .union
385            .borrow()
386            .upper_bound()
387            .min(self.rhs.induced_upper_bound(&state));
388        let holes = self
389            .union
390            .borrow()
391            .holes()
392            .chain(self.rhs.induced_holes(&state))
393            .collect();
394
395        let intersected_domain =
396            Domain::new(intersection_lower_bound, intersection_upper_bound, holes);
397
398        !intersected_domain.is_consistent()
399    }
400}
401
#[allow(deprecated, reason = "Will be refactored")]
#[cfg(test)]
mod tests {
    use pumpkin_checking::Comparison;
    use pumpkin_checking::TestAtomic;
    use pumpkin_checking::VariableState;
    use pumpkin_core::TestSolver;

    use super::*;

    #[test]
    fn elements_from_array_with_disjoint_domains_to_rhs_are_filtered_from_index() {
        let mut solver = TestSolver::default();

        let x_0 = solver.new_variable(4, 6);
        let x_1 = solver.new_variable(2, 3);
        let x_2 = solver.new_variable(7, 9);
        let x_3 = solver.new_variable(14, 15);

        let index = solver.new_variable(0, 3);
        let rhs = solver.new_variable(6, 9);
        let constraint_tag = solver.new_constraint_tag();

        let args = ElementArgs {
            array: vec![x_0, x_1, x_2, x_3].into(),
            index,
            rhs,
            constraint_tag,
        };
        let _ = solver.new_propagator(args).expect("no empty domains");

        // x_1 lies entirely below [6, 9] and x_3 entirely above it, so both are removed.
        solver.assert_bounds(index, 0, 2);

        let reason_for_removing_3 = solver.get_reason_int(predicate![index != 3]);
        assert_eq!(
            reason_for_removing_3,
            conjunction!([x_3 >= 10] & [rhs <= 9])
        );

        let reason_for_removing_1 = solver.get_reason_int(predicate![index != 1]);
        assert_eq!(
            reason_for_removing_1,
            conjunction!([x_1 <= 5] & [rhs >= 6])
        );
    }

    #[test]
    fn bounds_of_rhs_are_min_and_max_of_lower_and_upper_in_array() {
        let mut solver = TestSolver::default();

        let x_0 = solver.new_variable(3, 10);
        let x_1 = solver.new_variable(2, 3);
        let x_2 = solver.new_variable(7, 9);
        let x_3 = solver.new_variable(14, 15);

        let index = solver.new_variable(0, 3);
        let rhs = solver.new_variable(0, 20);
        let constraint_tag = solver.new_constraint_tag();

        let args = ElementArgs {
            array: vec![x_0, x_1, x_2, x_3].into(),
            index,
            rhs,
            constraint_tag,
        };
        let _ = solver.new_propagator(args).expect("no empty domains");

        // The smallest lower bound is that of x_1, the largest upper bound that of x_3.
        solver.assert_bounds(rhs, 2, 15);

        let reason_for_lower_bound = solver.get_reason_int(predicate![rhs >= 2]);
        assert_eq!(
            reason_for_lower_bound,
            conjunction!([x_0 >= 2] & [x_1 >= 2] & [x_2 >= 2] & [x_3 >= 2])
        );

        let reason_for_upper_bound = solver.get_reason_int(predicate![rhs <= 15]);
        assert_eq!(
            reason_for_upper_bound,
            conjunction!([x_0 <= 15] & [x_1 <= 15] & [x_2 <= 15] & [x_3 <= 15])
        );
    }

    #[test]
    fn fixed_index_propagates_bounds_on_element() {
        let mut solver = TestSolver::default();

        let x_0 = solver.new_variable(3, 10);
        let x_1 = solver.new_variable(0, 15);
        let x_2 = solver.new_variable(7, 9);
        let x_3 = solver.new_variable(14, 15);
        let constraint_tag = solver.new_constraint_tag();

        let index = solver.new_variable(1, 1);
        let rhs = solver.new_variable(6, 9);

        let args = ElementArgs {
            array: vec![x_0, x_1, x_2, x_3].into(),
            index,
            rhs,
            constraint_tag,
        };
        let _ = solver.new_propagator(args).expect("no empty domains");

        // The index selects x_1, which therefore inherits the bounds of the right-hand side.
        solver.assert_bounds(x_1, 6, 9);

        let reason_for_lower_bound = solver.get_reason_int(predicate![x_1 >= 6]);
        assert_eq!(
            reason_for_lower_bound,
            conjunction!([index == 1] & [rhs >= 6])
        );

        let reason_for_upper_bound = solver.get_reason_int(predicate![x_1 <= 9]);
        assert_eq!(
            reason_for_upper_bound,
            conjunction!([index == 1] & [rhs <= 9])
        );
    }

    #[test]
    fn index_hole_propagates_bounds_on_rhs() {
        let mut solver = TestSolver::default();

        let x_0 = solver.new_variable(3, 10);
        let x_1 = solver.new_variable(0, 15);
        let x_2 = solver.new_variable(7, 9);
        let x_3 = solver.new_variable(14, 15);
        let constraint_tag = solver.new_constraint_tag();

        let index = solver.new_variable(0, 3);
        solver.remove(index, 1).expect("Value can be removed");

        let rhs = solver.new_variable(-10, 30);

        let args = ElementArgs {
            array: vec![x_0, x_1, x_2, x_3].into(),
            index,
            rhs,
            constraint_tag,
        };
        let _ = solver.new_propagator(args).expect("no empty domains");

        // x_1 is no longer selectable, so its wide domain does not widen the bounds of rhs.
        solver.assert_bounds(rhs, 3, 15);

        let reason_for_lower_bound = solver.get_reason_int(predicate![rhs >= 3]);
        assert_eq!(
            reason_for_lower_bound,
            conjunction!([x_0 >= 3] & [x_2 >= 3] & [x_3 >= 3] & [index != 1])
        );

        let reason_for_upper_bound = solver.get_reason_int(predicate![rhs <= 15]);
        assert_eq!(
            reason_for_upper_bound,
            conjunction!([x_0 <= 15] & [x_2 <= 15] & [x_3 <= 15] & [index != 1])
        );
    }

    #[test]
    fn holes_outside_union_bounds_are_ignored() {
        // The value 2 is in neither `x1 >= 4` nor `x2 != 2`, so the inference `x4 != 2` must be
        // accepted.
        let premises = [
            TestAtomic {
                name: "x1",
                comparison: Comparison::GreaterEqual,
                value: 4,
            },
            TestAtomic {
                name: "x2",
                comparison: Comparison::NotEqual,
                value: 2,
            },
        ];
        let consequent = Some(TestAtomic {
            name: "x4",
            comparison: Comparison::NotEqual,
            value: 2,
        });

        let state = VariableState::prepare_for_conflict_check(premises, consequent)
            .expect("no conflicting atomics");
        let checker = ElementChecker::new(vec!["x1", "x2"].into(), "x3", "x4");

        assert!(checker.check(state, &premises, consequent.as_ref()));
    }
}