Skip to main content

pumpkin_core/engine/cp/
test_solver.rs

1//! This module exposes helpers that aid testing of CP propagators. The [`TestSolver`] allows
2//! setting up specific scenarios under which to test the various operations of a propagator.
3use std::fmt::Debug;
4
5use pumpkin_checking::InferenceChecker;
6
7use super::PropagatorQueue;
8use crate::containers::KeyGenerator;
9use crate::engine::EmptyDomain;
10use crate::engine::State;
11use crate::engine::predicates::predicate::Predicate;
12use crate::engine::variables::DomainId;
13use crate::engine::variables::IntegerVariable;
14use crate::engine::variables::Literal;
15use crate::options::LearningOptions;
16use crate::predicate;
17use crate::predicates::PropositionalConjunction;
18use crate::proof::ConstraintTag;
19use crate::proof::InferenceCode;
20use crate::propagation::EnqueueDecision;
21use crate::propagation::ExplanationContext;
22use crate::propagation::NotificationContext;
23use crate::propagation::PropagationContext;
24use crate::propagation::PropagatorConstructor;
25use crate::propagation::PropagatorId;
26use crate::propagators::nogoods::NogoodPropagator;
27use crate::propagators::nogoods::NogoodPropagatorConstructor;
28use crate::state::Conflict;
29use crate::state::PropagatorHandle;
30
31/// A container for CP variables, which can be used to test propagators.
32#[derive(Debug)]
33pub struct TestSolver {
34    pub state: State,
35    constraint_tags: KeyGenerator<ConstraintTag>,
36    pub nogood_handle: PropagatorHandle<NogoodPropagator>,
37}
38
39impl Default for TestSolver {
40    fn default() -> Self {
41        let mut state = State::default();
42        let handle = state.add_propagator(NogoodPropagatorConstructor::new(
43            0,
44            LearningOptions::default(),
45        ));
46        let mut solver = Self {
47            state,
48            constraint_tags: Default::default(),
49            nogood_handle: handle,
50        };
51        // We allocate space for the zero-th dummy variable at the root level of the assignments.
52        solver.state.notification_engine.grow();
53        solver
54    }
55}
56
57#[deprecated = "Will be replaced by the state API"]
58impl TestSolver {
59    pub fn accept_inferences_by(&mut self, inference_code: InferenceCode) {
60        #[derive(Debug, Clone, Copy)]
61        struct Checker;
62
63        impl InferenceChecker<Predicate> for Checker {
64            fn check(
65                &self,
66                _: pumpkin_checking::VariableState<Predicate>,
67                _: &[Predicate],
68                _: Option<&Predicate>,
69            ) -> bool {
70                true
71            }
72        }
73
74        self.state
75            .add_inference_checker(inference_code, Box::new(Checker));
76    }
77
78    pub fn new_variable(&mut self, lb: i32, ub: i32) -> DomainId {
79        self.state.new_interval_variable(lb, ub, None)
80    }
81
82    pub fn new_sparse_variable(&mut self, values: Vec<i32>) -> DomainId {
83        self.state.new_sparse_variable(values, None)
84    }
85
86    pub fn new_literal(&mut self) -> Literal {
87        let domain_id = self.new_variable(0, 1);
88        Literal::new(domain_id)
89    }
90
91    pub fn new_propagator<Constructor>(
92        &mut self,
93        constructor: Constructor,
94    ) -> Result<PropagatorId, Conflict>
95    where
96        Constructor: PropagatorConstructor,
97        Constructor::PropagatorImpl: 'static,
98    {
99        let handle = self.state.add_propagator(constructor);
100        self.state
101            .propagate_to_fixed_point()
102            .map(|_| handle.propagator_id())
103    }
104
105    pub fn contains<Var: IntegerVariable>(&self, var: Var, value: i32) -> bool {
106        var.contains(&self.state.assignments, value)
107    }
108
109    pub fn lower_bound(&self, var: DomainId) -> i32 {
110        self.state.assignments.get_lower_bound(var)
111    }
112
113    pub fn remove_and_notify(
114        &mut self,
115        propagator: PropagatorId,
116        var: DomainId,
117        value: i32,
118    ) -> EnqueueDecision {
119        let result = self.state.post(predicate!(var != value));
120        assert!(
121            result.is_ok(),
122            "The provided value to `increase_lower_bound` caused an empty domain, generally the propagator should not be notified of this change!"
123        );
124        let mut propagator_queue = PropagatorQueue::new(4);
125        #[allow(deprecated, reason = "Will be refactored in the future")]
126        self.state
127            .notification_engine
128            .notify_propagators_about_domain_events_test(
129                &mut self.state.assignments,
130                &mut self.state.trailed_values,
131                &mut self.state.propagators,
132                &mut propagator_queue,
133            );
134        if propagator_queue.is_propagator_enqueued(propagator) {
135            EnqueueDecision::Enqueue
136        } else {
137            EnqueueDecision::Skip
138        }
139    }
140
141    pub fn increase_lower_bound_and_notify(
142        &mut self,
143        propagator: PropagatorId,
144        _local_id: u32,
145        var: DomainId,
146        value: i32,
147    ) -> EnqueueDecision {
148        let result = self.state.post(predicate!(var >= value));
149        assert!(
150            result.is_ok(),
151            "The provided value to `increase_lower_bound` caused an empty domain, generally the propagator should not be notified of this change!"
152        );
153        let mut propagator_queue = PropagatorQueue::new(4);
154        #[allow(deprecated, reason = "Will be refactored in the future")]
155        self.state
156            .notification_engine
157            .notify_propagators_about_domain_events_test(
158                &mut self.state.assignments,
159                &mut self.state.trailed_values,
160                &mut self.state.propagators,
161                &mut propagator_queue,
162            );
163        if propagator_queue.is_propagator_enqueued(propagator) {
164            EnqueueDecision::Enqueue
165        } else {
166            EnqueueDecision::Skip
167        }
168    }
169
170    pub fn decrease_upper_bound_and_notify(
171        &mut self,
172        propagator: PropagatorId,
173        _local_id: u32,
174        var: DomainId,
175        value: i32,
176    ) -> EnqueueDecision {
177        let result = self.state.post(predicate!(var <= value));
178        assert!(
179            result.is_ok(),
180            "The provided value to `increase_lower_bound` caused an empty domain, generally the propagator should not be notified of this change!"
181        );
182        let mut propagator_queue = PropagatorQueue::new(4);
183        #[allow(deprecated, reason = "Will be refactored in the future")]
184        self.state
185            .notification_engine
186            .notify_propagators_about_domain_events_test(
187                &mut self.state.assignments,
188                &mut self.state.trailed_values,
189                &mut self.state.propagators,
190                &mut propagator_queue,
191            );
192        if propagator_queue.is_propagator_enqueued(propagator) {
193            EnqueueDecision::Enqueue
194        } else {
195            EnqueueDecision::Skip
196        }
197    }
198
199    pub fn is_literal_false(&self, literal: Literal) -> bool {
200        self.state
201            .assignments
202            .evaluate_predicate(literal.get_true_predicate())
203            .is_some_and(|truth_value| !truth_value)
204    }
205
206    pub fn upper_bound(&self, var: DomainId) -> i32 {
207        self.state.assignments.get_upper_bound(var)
208    }
209
210    pub fn remove(&mut self, var: DomainId, value: i32) -> Result<(), EmptyDomain> {
211        let _ = self.state.post(predicate!(var != value))?;
212
213        Ok(())
214    }
215
216    pub fn set_literal(&mut self, literal: Literal, truth_value: bool) -> Result<(), EmptyDomain> {
217        let _ = match truth_value {
218            true => self.state.assignments.post_predicate(
219                literal.get_true_predicate(),
220                None,
221                &mut self.state.notification_engine,
222            )?,
223            false => self.state.assignments.post_predicate(
224                (!literal).get_true_predicate(),
225                None,
226                &mut self.state.notification_engine,
227            )?,
228        };
229
230        Ok(())
231    }
232
233    pub fn propagate(&mut self, propagator: PropagatorId) -> Result<(), Conflict> {
234        let context = PropagationContext::new(
235            &mut self.state.trailed_values,
236            &mut self.state.assignments,
237            &mut self.state.reason_store,
238            &mut self.state.notification_engine,
239            propagator,
240        );
241        self.state.propagators[propagator].propagate(context)
242    }
243
244    pub fn propagate_until_fixed_point(
245        &mut self,
246        propagator: PropagatorId,
247    ) -> Result<(), Conflict> {
248        let mut num_trail_entries = self.state.assignments.num_trail_entries();
249        self.notify_propagator(propagator);
250        loop {
251            {
252                // Specify the life-times to be able to retrieve the trail entries
253                let context = PropagationContext::new(
254                    &mut self.state.trailed_values,
255                    &mut self.state.assignments,
256                    &mut self.state.reason_store,
257                    &mut self.state.notification_engine,
258                    propagator,
259                );
260                self.state.propagators[propagator].propagate(context)?;
261                self.notify_propagator(propagator);
262            }
263            if self.state.assignments.num_trail_entries() == num_trail_entries {
264                break;
265            }
266            num_trail_entries = self.state.assignments.num_trail_entries();
267        }
268        Ok(())
269    }
270
271    pub fn notify_propagator(&mut self, _propagator: PropagatorId) {
272        #[allow(deprecated, reason = "Will be refactored in the future")]
273        self.state
274            .notification_engine
275            .notify_propagators_about_domain_events_test(
276                &mut self.state.assignments,
277                &mut self.state.trailed_values,
278                &mut self.state.propagators,
279                &mut PropagatorQueue::new(4),
280            );
281    }
282
283    pub fn get_reason_int(&mut self, predicate: Predicate) -> PropositionalConjunction {
284        #[allow(deprecated, reason = "Will be refactored in the future")]
285        let reason_ref = self
286            .state
287            .assignments
288            .get_reason_for_predicate_brute_force(predicate);
289        let mut predicates = vec![];
290        let _ = self.state.reason_store.get_or_compute(
291            reason_ref,
292            ExplanationContext::without_working_nogood(
293                &self.state.assignments,
294                self.state
295                    .assignments
296                    .get_trail_position(&predicate)
297                    .unwrap(),
298                &mut self.state.notification_engine,
299            ),
300            &mut self.state.propagators,
301            &mut predicates,
302        );
303
304        PropositionalConjunction::from(predicates)
305    }
306
307    pub fn get_reason_bool(
308        &mut self,
309        literal: Literal,
310        truth_value: bool,
311    ) -> PropositionalConjunction {
312        let predicate = match truth_value {
313            true => literal.get_true_predicate(),
314            false => (!literal).get_true_predicate(),
315        };
316        self.get_reason_int(predicate)
317    }
318
319    pub fn assert_bounds(&self, var: DomainId, lb: i32, ub: i32) {
320        let actual_lb = self.lower_bound(var);
321        let actual_ub = self.upper_bound(var);
322
323        assert_eq!(
324            (lb, ub),
325            (actual_lb, actual_ub),
326            "The expected bounds [{lb}..{ub}] did not match the actual bounds [{actual_lb}..{actual_ub}]"
327        );
328    }
329
330    pub fn new_constraint_tag(&mut self) -> ConstraintTag {
331        self.constraint_tags.next_key()
332    }
333
334    pub fn new_checkpoint(&mut self) {
335        self.state.new_checkpoint();
336    }
337
338    pub fn synchronise(&mut self, level: usize) {
339        let _ = self
340            .state
341            .assignments
342            .synchronise(level, &mut self.state.notification_engine);
343        self.state.notification_engine.synchronise(
344            level,
345            &self.state.assignments,
346            &mut self.state.trailed_values,
347        );
348        self.state.trailed_values.synchronise(level);
349
350        for propagator in self.state.propagators.iter_propagators_mut() {
351            let mut context =
352                NotificationContext::new(&mut self.state.trailed_values, &self.state.assignments);
353
354            propagator.synchronise(context.reborrow());
355        }
356    }
357}