1#![allow(clippy::double_parens, reason = "originates inside the bitfield macro")]
4
5use std::cell::RefCell;
6
7use bitfield_struct::bitfield;
8use pumpkin_checking::AtomicConstraint;
9use pumpkin_checking::CheckerVariable;
10use pumpkin_checking::Domain;
11use pumpkin_checking::InferenceChecker;
12use pumpkin_checking::Union;
13use pumpkin_core::conjunction;
14use pumpkin_core::declare_inference_label;
15use pumpkin_core::predicate;
16use pumpkin_core::predicates::Predicate;
17use pumpkin_core::proof::ConstraintTag;
18use pumpkin_core::proof::InferenceCode;
19use pumpkin_core::propagation::DomainEvents;
20use pumpkin_core::propagation::ExplanationContext;
21use pumpkin_core::propagation::InferenceCheckers;
22use pumpkin_core::propagation::LocalId;
23use pumpkin_core::propagation::Priority;
24use pumpkin_core::propagation::PropagationContext;
25use pumpkin_core::propagation::Propagator;
26use pumpkin_core::propagation::PropagatorConstructor;
27use pumpkin_core::propagation::PropagatorConstructorContext;
28use pumpkin_core::propagation::ReadDomains;
29use pumpkin_core::results::PropagationStatusCP;
30use pumpkin_core::variables::IntegerVariable;
31use pumpkin_core::variables::Reason;
32
/// Arguments of the element constraint `array[index] = rhs`.
///
/// Passing these to the solver creates an [`ElementPropagator`] (see
/// `PropagatorConstructor::create`) and registers an [`ElementChecker`] for its inferences.
#[derive(Clone, Debug)]
pub struct ElementArgs<VX, VI, VE> {
    /// The variables that `index` selects from.
    pub array: Box<[VX]>,
    /// Zero-based position into `array`; the propagator restricts it to `0..array.len()`.
    pub index: VI,
    /// The variable that must equal the selected array element.
    pub rhs: VE,
    /// Tag combined with the `Element` label to form the inference code of this constraint.
    pub constraint_tag: ConstraintTag,
}
40
// Inference label for this constraint. It is paired with a `ConstraintTag` via
// `InferenceCode::new(tag, Element)` both when propagating and when registering the checker.
declare_inference_label!(Element);
42
43impl<VX, VI, VE> PropagatorConstructor for ElementArgs<VX, VI, VE>
44where
45 VX: IntegerVariable + 'static,
46 VI: IntegerVariable + 'static,
47 VE: IntegerVariable + 'static,
48{
49 type PropagatorImpl = ElementPropagator<VX, VI, VE>;
50
51 fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) {
52 checkers.add_inference_checker(
53 InferenceCode::new(self.constraint_tag, Element),
54 Box::new(ElementChecker::new(
55 self.array.clone(),
56 self.index.clone(),
57 self.rhs.clone(),
58 )),
59 );
60 }
61
62 fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl {
63 let ElementArgs {
64 array,
65 index,
66 rhs,
67 constraint_tag,
68 } = self;
69
70 for (i, x_i) in array.iter().enumerate() {
71 context.register(
72 x_i.clone(),
73 DomainEvents::ANY_INT,
74 LocalId::from(i as u32 + ID_X_OFFSET),
75 );
76 }
77
78 context.register(index.clone(), DomainEvents::ANY_INT, ID_INDEX);
79 context.register(rhs.clone(), DomainEvents::ANY_INT, ID_RHS);
80
81 let inference_code = InferenceCode::new(constraint_tag, Element);
82
83 ElementPropagator {
84 array,
85 index,
86 rhs,
87 inference_code,
88 rhs_reason_buffer: vec![],
89 }
90 }
91}
92
/// Local id under which the `index` variable is registered.
const ID_INDEX: LocalId = LocalId::from(0);
/// Local id under which the `rhs` variable is registered.
const ID_RHS: LocalId = LocalId::from(1);

/// The array element at position `i` is registered under local id `ID_X_OFFSET + i`,
/// which places every element after `ID_INDEX` and `ID_RHS`.
const ID_X_OFFSET: u32 = 2;
98
/// Bounds-based propagator for `array[index] = rhs`.
///
/// It keeps `index` within the array, bounds `rhs` by the elements still selectable,
/// removes index values whose element cannot equal `rhs`, and bounds the selected element
/// once `index` is fixed.
#[derive(Clone, Debug)]
pub struct ElementPropagator<VX, VI, VE> {
    array: Box<[VX]>,
    index: VI,
    rhs: VE,
    inference_code: InferenceCode,

    /// Scratch buffer reused by `lazy_explanation`; the returned slice borrows from it.
    rhs_reason_buffer: Vec<Predicate>,
}
112
113impl<VX, VI, VE> Propagator for ElementPropagator<VX, VI, VE>
114where
115 VX: IntegerVariable + 'static,
116 VI: IntegerVariable + 'static,
117 VE: IntegerVariable + 'static,
118{
119 fn priority(&self) -> Priority {
120 Priority::Low
121 }
122
123 fn name(&self) -> &str {
124 "Element"
125 }
126
127 fn propagate_from_scratch(&self, mut context: PropagationContext) -> PropagationStatusCP {
128 self.propagate_index_bounds_within_array(&mut context)?;
129
130 self.propagate_rhs_bounds_based_on_array(&mut context)?;
131
132 self.propagate_index_based_on_domain_intersection_with_rhs(&mut context)?;
133
134 if context.is_fixed(&self.index) {
135 let idx = context.lower_bound(&self.index);
136 self.propagate_equality(&mut context, idx)?;
137 }
138
139 Ok(())
140 }
141
142 fn lazy_explanation(&mut self, code: u64, context: ExplanationContext) -> &[Predicate] {
143 let payload = RightHandSideReason::from_bits(code);
144
145 self.rhs_reason_buffer.clear();
146 self.rhs_reason_buffer
147 .extend(self.array.iter().enumerate().map(|(idx, variable)| {
148 if context.contains_at_trail_position(
149 &self.index,
150 idx as i32,
151 context.get_trail_position(),
152 ) {
153 match payload.bound() {
154 Bound::Lower => predicate![variable >= payload.value()],
155 Bound::Upper => predicate![variable <= payload.value()],
156 }
157 } else {
158 predicate![self.index != idx as i32]
159 }
160 }));
161
162 &self.rhs_reason_buffer
163 }
164}
165
166impl<VX, VI, VE> ElementPropagator<VX, VI, VE>
167where
168 VX: IntegerVariable + 'static,
169 VI: IntegerVariable + 'static,
170 VE: IntegerVariable + 'static,
171{
172 fn propagate_index_bounds_within_array(
174 &self,
175 context: &mut PropagationContext<'_>,
176 ) -> PropagationStatusCP {
177 context.post(
178 predicate![self.index >= 0],
179 conjunction!(),
180 &self.inference_code,
181 )?;
182 context.post(
183 predicate![self.index <= self.array.len() as i32 - 1],
184 conjunction!(),
185 &self.inference_code,
186 )?;
187 Ok(())
188 }
189
190 fn propagate_rhs_bounds_based_on_array(
193 &self,
194 context: &mut PropagationContext<'_>,
195 ) -> PropagationStatusCP {
196 let (rhs_lb, rhs_ub) = self
197 .array
198 .iter()
199 .enumerate()
200 .filter(|(idx, _)| context.contains(&self.index, *idx as i32))
201 .fold((i32::MAX, i32::MIN), |(rhs_lb, rhs_ub), (_, element)| {
202 (
203 i32::min(rhs_lb, context.lower_bound(element)),
204 i32::max(rhs_ub, context.upper_bound(element)),
205 )
206 });
207
208 context.post(
209 predicate![self.rhs >= rhs_lb],
210 Reason::DynamicLazy(
211 RightHandSideReason::new()
212 .with_bound(Bound::Lower)
213 .with_value(rhs_lb)
214 .into_bits(),
215 ),
216 &self.inference_code,
217 )?;
218 context.post(
219 predicate![self.rhs <= rhs_ub],
220 Reason::DynamicLazy(
221 RightHandSideReason::new()
222 .with_bound(Bound::Upper)
223 .with_value(rhs_ub)
224 .into_bits(),
225 ),
226 &self.inference_code,
227 )?;
228
229 Ok(())
230 }
231
232 fn propagate_index_based_on_domain_intersection_with_rhs(
235 &self,
236 context: &mut PropagationContext<'_>,
237 ) -> PropagationStatusCP {
238 let rhs_lb = context.lower_bound(&self.rhs);
239 let rhs_ub = context.upper_bound(&self.rhs);
240 let mut to_remove = vec![];
241 for idx in context.iterate_domain(&self.index) {
242 let element = &self.array[idx as usize];
243
244 let element_ub = context.upper_bound(element);
245 let element_lb = context.lower_bound(element);
246
247 let reason = if rhs_lb > element_ub {
248 conjunction!([element <= rhs_lb - 1] & [self.rhs >= rhs_lb])
249 } else if rhs_ub < element_lb {
250 conjunction!([element >= rhs_ub + 1] & [self.rhs <= rhs_ub])
251 } else {
252 continue;
253 };
254
255 to_remove.push((idx, reason));
256 }
257
258 for (idx, reason) in to_remove.drain(..) {
259 context.post(predicate![self.index != idx], reason, &self.inference_code)?;
260 }
261
262 Ok(())
263 }
264
265 fn propagate_equality(
268 &self,
269 context: &mut PropagationContext<'_>,
270 index: i32,
271 ) -> PropagationStatusCP {
272 let rhs_lb = context.lower_bound(&self.rhs);
273 let rhs_ub = context.upper_bound(&self.rhs);
274 let lhs = &self.array[index as usize];
275
276 context.post(
277 predicate![lhs >= rhs_lb],
278 conjunction!([self.rhs >= rhs_lb] & [self.index == index]),
279 &self.inference_code,
280 )?;
281 context.post(
282 predicate![lhs <= rhs_ub],
283 conjunction!([self.rhs <= rhs_ub] & [self.index == index]),
284 &self.inference_code,
285 )?;
286 Ok(())
287 }
288}
289
/// Which bound of `rhs` a lazily explained propagation refers to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum Bound {
    /// A lower bound, `rhs >= value`.
    Lower = 0,
    /// An upper bound, `rhs <= value`.
    Upper = 1,
}

impl Bound {
    /// Encodes the bound as its discriminant (`Lower` is 0, `Upper` is 1).
    const fn into_bits(self) -> u8 {
        match self {
            Bound::Lower => 0,
            Bound::Upper => 1,
        }
    }

    /// Decodes a bound: `0` is `Lower` and any other value is treated as `Upper`.
    const fn from_bits(value: u8) -> Self {
        if value == 0 {
            Bound::Lower
        } else {
            Bound::Upper
        }
    }
}
309
/// Payload of the `Reason::DynamicLazy` code used for bounds posted on `rhs`.
///
/// It is packed into a `u64` when the bound is posted and decoded with `from_bits` in
/// `lazy_explanation` to rebuild the explanation.
#[bitfield(u64)]
struct RightHandSideReason {
    /// Whether the explained predicate is a lower or an upper bound on `rhs`.
    #[bits(32, from = Bound::from_bits)]
    bound: Bound,
    /// The bound value that was posted on `rhs`.
    value: i32,
}
316
/// Inference checker for the `Element` inferences of `array[index] = rhs`.
#[derive(Clone, Debug)]
pub struct ElementChecker<VX, VI, VE> {
    array: Box<[VX]>,
    index: VI,
    rhs: VE,

    /// Scratch union of the selectable element domains. It lives in a `RefCell` because
    /// `check` only receives `&self`, and it is reset at the start of every check.
    union: RefCell<Union>,
}
325
326impl<VX, VI, VE> ElementChecker<VX, VI, VE> {
327 pub fn new(array: Box<[VX]>, index: VI, rhs: VE) -> Self {
329 ElementChecker {
330 array,
331 index,
332 rhs,
333 union: RefCell::new(Union::empty()),
334 }
335 }
336}
337
338impl<VX, VI, VE, Atomic> InferenceChecker<Atomic> for ElementChecker<VX, VI, VE>
339where
340 Atomic: AtomicConstraint,
341 VX: CheckerVariable<Atomic>,
342 VI: CheckerVariable<Atomic>,
343 VE: CheckerVariable<Atomic>,
344{
345 fn check(
346 &self,
347 state: pumpkin_checking::VariableState<Atomic>,
348 _: &[Atomic],
349 _: Option<&Atomic>,
350 ) -> bool {
351 self.union.borrow_mut().reset();
352
353 let supported_elements: Vec<_> = self
361 .array
362 .iter()
363 .enumerate()
364 .filter(|(idx, _)| self.index.induced_domain_contains(&state, *idx as i32))
365 .map(|(_, element)| element)
366 .collect();
367
368 for element in supported_elements {
369 self.union.borrow_mut().add(&state, element);
370 }
371
372 assert!(
373 self.union.borrow().is_consistent(),
374 "at least one element has a non-empty domain or else variable state would be inconsistent"
375 );
376
377 let intersection_lower_bound = self
379 .union
380 .borrow()
381 .lower_bound()
382 .max(self.rhs.induced_lower_bound(&state));
383 let intersection_upper_bound = self
384 .union
385 .borrow()
386 .upper_bound()
387 .min(self.rhs.induced_upper_bound(&state));
388 let holes = self
389 .union
390 .borrow()
391 .holes()
392 .chain(self.rhs.induced_holes(&state))
393 .collect();
394
395 let intersected_domain =
396 Domain::new(intersection_lower_bound, intersection_upper_bound, holes);
397
398 !intersected_domain.is_consistent()
399 }
400}
401
#[allow(deprecated, reason = "Will be refactored")]
#[cfg(test)]
mod tests {
    use pumpkin_checking::TestAtomic;
    use pumpkin_checking::VariableState;
    use pumpkin_core::TestSolver;

    use super::*;

    /// Index values whose element's bounds are disjoint from the bounds of `rhs` are
    /// removed. Each removal is explained by the separating element bound plus the `rhs`
    /// bound that causes the separation.
    #[test]
    fn elements_from_array_with_disjoint_domains_to_rhs_are_filtered_from_index() {
        let mut solver = TestSolver::default();

        let x_0 = solver.new_variable(4, 6);
        let x_1 = solver.new_variable(2, 3);
        let x_2 = solver.new_variable(7, 9);
        let x_3 = solver.new_variable(14, 15);

        let index = solver.new_variable(0, 3);
        let rhs = solver.new_variable(6, 9);
        let constraint_tag = solver.new_constraint_tag();

        let _ = solver
            .new_propagator(ElementArgs {
                array: vec![x_0, x_1, x_2, x_3].into(),
                index,
                rhs,
                constraint_tag,
            })
            .expect("no empty domains");

        // x_3 lies entirely above rhs, so index 3 is removed and the upper bound drops to 2.
        solver.assert_bounds(index, 0, 2);

        assert_eq!(
            solver.get_reason_int(predicate![index != 3]),
            conjunction!([x_3 >= 10] & [rhs <= 9])
        );

        // x_1 lies entirely below rhs, so index 1 is removed (a hole inside the bounds).
        assert_eq!(
            solver.get_reason_int(predicate![index != 1]),
            conjunction!([x_1 <= 5] & [rhs >= 6])
        );
    }

    /// `rhs` is bounded by the smallest lower bound and the largest upper bound of the
    /// selectable elements. The lazy explanation lists that bound for every element.
    #[test]
    fn bounds_of_rhs_are_min_and_max_of_lower_and_upper_in_array() {
        let mut solver = TestSolver::default();

        let x_0 = solver.new_variable(3, 10);
        let x_1 = solver.new_variable(2, 3);
        let x_2 = solver.new_variable(7, 9);
        let x_3 = solver.new_variable(14, 15);

        let index = solver.new_variable(0, 3);
        let rhs = solver.new_variable(0, 20);
        let constraint_tag = solver.new_constraint_tag();

        let _ = solver
            .new_propagator(ElementArgs {
                array: vec![x_0, x_1, x_2, x_3].into(),
                index,
                rhs,
                constraint_tag,
            })
            .expect("no empty domains");

        solver.assert_bounds(rhs, 2, 15);

        assert_eq!(
            solver.get_reason_int(predicate![rhs >= 2]),
            conjunction!([x_0 >= 2] & [x_1 >= 2] & [x_2 >= 2] & [x_3 >= 2])
        );

        assert_eq!(
            solver.get_reason_int(predicate![rhs <= 15]),
            conjunction!([x_0 <= 15] & [x_1 <= 15] & [x_2 <= 15] & [x_3 <= 15])
        );
    }

    /// With `index` fixed, the selected element is bounded by the bounds of `rhs`.
    /// The explanation is the fixed index plus the corresponding `rhs` bound.
    #[test]
    fn fixed_index_propagates_bounds_on_element() {
        let mut solver = TestSolver::default();

        let x_0 = solver.new_variable(3, 10);
        let x_1 = solver.new_variable(0, 15);
        let x_2 = solver.new_variable(7, 9);
        let x_3 = solver.new_variable(14, 15);
        let constraint_tag = solver.new_constraint_tag();

        let index = solver.new_variable(1, 1);
        let rhs = solver.new_variable(6, 9);

        let _ = solver
            .new_propagator(ElementArgs {
                array: vec![x_0, x_1, x_2, x_3].into(),
                index,
                rhs,
                constraint_tag,
            })
            .expect("no empty domains");

        solver.assert_bounds(x_1, 6, 9);

        assert_eq!(
            solver.get_reason_int(predicate![x_1 >= 6]),
            conjunction!([index == 1] & [rhs >= 6])
        );

        assert_eq!(
            solver.get_reason_int(predicate![x_1 <= 9]),
            conjunction!([index == 1] & [rhs <= 9])
        );
    }

    /// An element whose position is a hole in the `index` domain does not contribute to
    /// the bounds of `rhs`. It appears in the explanation as `index != position`.
    #[test]
    fn index_hole_propagates_bounds_on_rhs() {
        let mut solver = TestSolver::default();

        let x_0 = solver.new_variable(3, 10);
        let x_1 = solver.new_variable(0, 15);
        let x_2 = solver.new_variable(7, 9);
        let x_3 = solver.new_variable(14, 15);
        let constraint_tag = solver.new_constraint_tag();

        let index = solver.new_variable(0, 3);
        solver.remove(index, 1).expect("Value can be removed");

        let rhs = solver.new_variable(-10, 30);

        let _ = solver
            .new_propagator(ElementArgs {
                array: vec![x_0, x_1, x_2, x_3].into(),
                index,
                rhs,
                constraint_tag,
            })
            .expect("no empty domains");

        // x_1 (domain 0..=15) is excluded, so the lower bound comes from x_0 rather than x_1.
        solver.assert_bounds(rhs, 3, 15);

        assert_eq!(
            solver.get_reason_int(predicate![rhs >= 3]),
            conjunction!([x_0 >= 3] & [x_2 >= 3] & [x_3 >= 3] & [index != 1])
        );

        assert_eq!(
            solver.get_reason_int(predicate![rhs <= 15]),
            conjunction!([x_0 <= 15] & [x_2 <= 15] & [x_3 <= 15] & [index != 1])
        );
    }

    /// Checker test. Index `x3` is unconstrained, so both `x1` and `x2` are selectable.
    /// The negated consequent fixes `x4 = 2`, while `x1 >= 4` and `x2 != 2` presumably keep
    /// 2 out of the union of element domains. The intersection is then expected to be
    /// empty, so the check is expected to succeed.
    #[test]
    fn holes_outside_union_bounds_are_ignored() {
        let premises = [
            TestAtomic {
                name: "x1",
                comparison: pumpkin_checking::Comparison::GreaterEqual,
                value: 4,
            },
            TestAtomic {
                name: "x2",
                comparison: pumpkin_checking::Comparison::NotEqual,
                value: 2,
            },
        ];

        let consequent = Some(TestAtomic {
            name: "x4",
            comparison: pumpkin_checking::Comparison::NotEqual,
            value: 2,
        });
        let state = VariableState::prepare_for_conflict_check(premises, consequent)
            .expect("no conflicting atomics");

        let checker = ElementChecker::new(vec!["x1", "x2"].into(), "x3", "x4");

        assert!(checker.check(state, &premises, consequent.as_ref()));
    }
}