Skip to main content

pumpkin_core/engine/
state.rs

1use std::sync::Arc;
2
3use pumpkin_checking::BoxedChecker;
4use pumpkin_checking::InferenceChecker;
5#[cfg(feature = "check-propagations")]
6use pumpkin_checking::VariableState;
7
8use crate::basic_types::PropagatorConflict;
9use crate::containers::HashMap;
10use crate::containers::KeyGenerator;
11use crate::create_statistics_struct;
12use crate::engine::Assignments;
13use crate::engine::ConstraintProgrammingTrailEntry;
14use crate::engine::DebugHelper;
15use crate::engine::EmptyDomain;
16use crate::engine::PropagatorQueue;
17#[cfg(test)]
18use crate::engine::Reason;
19use crate::engine::TrailedValues;
20use crate::engine::VariableNames;
21use crate::engine::notifications::NotificationEngine;
22use crate::engine::reason::ReasonRef;
23use crate::engine::reason::ReasonStore;
24use crate::predicate;
25use crate::predicates::Predicate;
26use crate::predicates::PredicateType;
27use crate::proof::ConstraintTag;
28use crate::proof::InferenceCode;
29#[cfg(doc)]
30use crate::proof::ProofLog;
31use crate::propagation::CurrentNogood;
32use crate::propagation::Domains;
33use crate::propagation::ExplanationContext;
34#[cfg(feature = "check-propagations")]
35use crate::propagation::InferenceCheckers;
36use crate::propagation::NotificationContext;
37use crate::propagation::PropagationContext;
38use crate::propagation::Propagator;
39use crate::propagation::PropagatorConstructor;
40use crate::propagation::PropagatorConstructorContext;
41use crate::propagation::PropagatorId;
42use crate::propagation::store::PropagatorStore;
43use crate::pumpkin_assert_advanced;
44use crate::pumpkin_assert_eq_simple;
45use crate::pumpkin_assert_extreme;
46use crate::pumpkin_assert_simple;
47use crate::results::SolutionReference;
48use crate::state::PropagatorHandle;
49use crate::statistics::StatisticLogger;
50use crate::statistics::log_statistic;
51use crate::variables::DomainId;
52use crate::variables::IntegerVariable;
53use crate::variables::Literal;
54
/// The [`State`] is the container of variables and propagators.
///
/// Besides the variable domains (held in [`Assignments`]) and the propagators, it owns the
/// propagator queue, the store of propagation reasons, the notification machinery that links
/// domain changes to propagators, and the registered inference checkers.
///
/// [`State`] implements [`Clone`], and cloning the [`State`] will create a fresh copy of the
/// [`State`]. If the [`State`] is large, this may be extremely expensive.
#[derive(Debug, Clone)]
pub struct State {
    /// The list of propagators; propagators live here and are queried when events (domain changes)
    /// happen.
    pub(crate) propagators: PropagatorStore,
    /// Tracks information related to the assignments of integer variables.
    pub(crate) assignments: Assignments,
    /// Keep track of trailed values (i.e. values which automatically backtrack).
    pub(crate) trailed_values: TrailedValues,
    /// The names of the variables in the solver.
    pub(crate) variable_names: VariableNames,
    /// Dictates the order in which propagators will be called to propagate.
    pub(crate) propagator_queue: PropagatorQueue,
    /// Handles storing information about propagation reasons, which are used later to construct
    /// explanations during conflict analysis.
    pub(crate) reason_store: ReasonStore,
    /// Component responsible for providing notifications for changes to the domains of variables
    /// and/or the polarity [Predicate]s
    pub(crate) notification_engine: NotificationEngine,

    /// The [`ConstraintTag`]s generated for this proof.
    pub(crate) constraint_tags: KeyGenerator<ConstraintTag>,

    /// Counters (propagator calls, conflicts, backjumps, ...) that are reported when the
    /// statistics of the state are logged.
    statistics: StateStatistics,

    /// Inference checkers to run in the propagation loop, grouped by the [`InferenceCode`] whose
    /// inferences they check. Filled by [`State::add_inference_checker`].
    checkers: HashMap<InferenceCode, Vec<BoxedChecker<Predicate>>>,
}
87
create_statistics_struct!(StateStatistics {
    /// The number of times a single propagator was invoked.
    num_propagators_called: usize,
    /// Logged as `numAtomicConstraintsPropagated` when verbose statistics are requested.
    num_propagations: usize,
    /// The number of conflicts reported during propagation.
    num_conflicts: usize,
    /// The total number of levels which were backjumped.
    ///
    /// For a restore from the current checkpoint `c` to an earlier checkpoint `t`, this grows by
    /// `c - 1 - t`, i.e. the number of levels undone beyond the single level that a plain
    /// (chronological) backtrack would undo.
    sum_of_backjumps: u64,
    /// The number of backjumps, i.e. restores that undo more than a single checkpoint (for
    /// example due to a learned nogood).
    num_backjumps: u64,
});
102
/// Information concerning the conflict returned by [`State::propagate_to_fixed_point`].
///
/// Two (related) conflicts can happen:
/// 1) a propagator explicitly detects a conflict.
/// 2) a propagator posts a domain change that results in a variable having an empty domain.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Conflict {
    /// A conflict raised explicitly by a propagator.
    Propagator(PropagatorConflict),
    /// A conflict caused by the domain of a variable becoming empty.
    EmptyDomain(EmptyDomainConflict),
}
115
116impl From<EmptyDomainConflict> for Conflict {
117    fn from(value: EmptyDomainConflict) -> Self {
118        Conflict::EmptyDomain(value)
119    }
120}
121
122impl From<PropagatorConflict> for Conflict {
123    fn from(value: PropagatorConflict) -> Self {
124        Conflict::Propagator(value)
125    }
126}
127
/// A conflict because a domain became empty.
///
/// It records the predicate whose application emptied the domain, together with a reference to
/// the reason for that predicate and the accompanying [`InferenceCode`]. The reason itself can be
/// retrieved with [`EmptyDomainConflict::get_reason`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EmptyDomainConflict {
    /// The predicate that caused a domain to become empty.
    pub trigger_predicate: Predicate,
    /// A reference, resolved through the state's reason store, to the reason for
    /// [`EmptyDomainConflict::trigger_predicate`] to be true.
    pub(crate) trigger_reason: ReasonRef,
    /// The [`InferenceCode`] that accompanies [`EmptyDomainConflict::trigger_reason`].
    pub(crate) trigger_inference_code: InferenceCode,
}
138
139impl EmptyDomainConflict {
140    /// The domain that became empty.
141    pub fn domain(&self) -> DomainId {
142        self.trigger_predicate.get_domain()
143    }
144
145    /// Returns the reason for the [`EmptyDomainConflict::trigger_predicate`] being propagated to
146    /// true while it is already false in the [`State`].
147    pub fn get_reason(
148        &self,
149        state: &mut State,
150        reason_buffer: &mut (impl Extend<Predicate> + AsRef<[Predicate]>),
151        current_nogood: CurrentNogood,
152    ) {
153        let _ = state.reason_store.get_or_compute(
154            self.trigger_reason,
155            ExplanationContext::new(
156                &state.assignments,
157                current_nogood,
158                state.trail_len(),
159                &mut state.notification_engine,
160            ),
161            &mut state.propagators,
162            reason_buffer,
163        );
164    }
165}
166
167impl Default for State {
168    fn default() -> Self {
169        let mut result = Self {
170            assignments: Default::default(),
171            trailed_values: TrailedValues::default(),
172            variable_names: VariableNames::default(),
173            propagator_queue: PropagatorQueue::default(),
174            propagators: PropagatorStore::default(),
175            reason_store: ReasonStore::default(),
176            notification_engine: NotificationEngine::default(),
177            statistics: StateStatistics::default(),
178            constraint_tags: KeyGenerator::default(),
179            checkers: HashMap::default(),
180        };
181        // As a convention, the assignments contain a dummy domain_id=0, which represents a 0-1
182        // variable that is assigned to one. We use it to represent predicates that are
183        // trivially true. We need to adjust other data structures to take this into account.
184        let dummy_id = Predicate::trivially_true().get_domain();
185
186        result.variable_names.add_integer(dummy_id, "Dummy".into());
187        assert!(dummy_id.id() == 0);
188        assert!(result.assignments.get_lower_bound(dummy_id) == 1);
189        assert!(result.assignments.get_upper_bound(dummy_id) == 1);
190
191        result
192    }
193}
194
195impl State {
196    pub(crate) fn log_statistics(&self, verbose: bool) {
197        log_statistic("variables", self.assignments.num_domains());
198        log_statistic("propagators", self.propagators.num_propagators());
199        log_statistic("failures", self.statistics.num_conflicts);
200        log_statistic("propagations", self.statistics.num_propagators_called);
201        log_statistic("nogoods", self.statistics.num_conflicts);
202        if verbose {
203            log_statistic(
204                "numAtomicConstraintsPropagated",
205                self.statistics.num_propagations,
206            );
207            for (index, propagator) in self.propagators.iter_propagators().enumerate() {
208                propagator.log_statistics(StatisticLogger::new([
209                    propagator.name(),
210                    "number",
211                    index.to_string().as_str(),
212                ]));
213            }
214        }
215    }
216}
217
218/// Operations to create .
219impl State {
220    /// Create a new [`ConstraintTag`].
221    pub fn new_constraint_tag(&mut self) -> ConstraintTag {
222        self.constraint_tags.next_key()
223    }
224
225    /// Creates a new Boolean (0-1) variable.
226    ///
227    /// The name is used in solver traces to identify individual domains. They are required to be
228    /// unique. If the state already contains a domain with the given name, then this function
229    /// will panic.
230    ///
231    /// Creation of new [`Literal`]s is not influenced by the current checkpoint of the state.
232    /// If a [`Literal`] is created at a non-zero checkpoint, then it will _not_ 'disappear'
233    /// when backtracking past the checkpoint where the domain was created.
234    pub fn new_literal(&mut self, name: Option<Arc<str>>) -> Literal {
235        let domain_id = self.new_interval_variable(0, 1, name);
236        Literal::new(domain_id)
237    }
238
239    /// Creates a new interval variable with the given lower and upper bound.
240    ///
241    /// The name is used in solver traces to identify individual domains. They are required to be
242    /// unique. If the state already contains a domain with the given name, then this function
243    /// will panic.
244    ///
245    /// Creation of new domains is not influenced by the current checkpoint of the state. If
246    /// a domain is created at a non-zero checkpoint, then it will _not_ 'disappear' when
247    /// backtracking past the checkpoint where the domain was created.)
248    pub fn new_interval_variable(
249        &mut self,
250        lower_bound: i32,
251        upper_bound: i32,
252        name: Option<Arc<str>>,
253    ) -> DomainId {
254        let domain_id = self.assignments.grow(lower_bound, upper_bound);
255
256        if let Some(name) = name {
257            self.variable_names.add_integer(domain_id, name);
258        }
259
260        self.notification_engine.grow();
261
262        domain_id
263    }
264
265    /// Creates a new sparse domain with the given values.
266    ///
267    /// Note that this is implemented as an interval domain with explicit holes in the domain. For
268    /// very sparse domains, this can result in a high memory overhead.
269    ///
270    /// For more information on creation of domains, see [`State::new_interval_variable`].
271    pub fn new_sparse_variable(&mut self, values: Vec<i32>, name: Option<String>) -> DomainId {
272        let domain_id = self.assignments.create_new_integer_variable_sparse(values);
273
274        if let Some(name) = name {
275            self.variable_names.add_integer(domain_id, name.into());
276        }
277
278        self.notification_engine.grow();
279
280        domain_id
281    }
282}
283
284/// Operations to retrieve information about values
285impl State {
286    /// Returns the lower-bound of the given `variable`.
287    pub fn lower_bound<Var: IntegerVariable>(&self, variable: Var) -> i32 {
288        variable.lower_bound(&self.assignments)
289    }
290
291    /// Returns the upper-bound of the given `variable`.
292    pub fn upper_bound<Var: IntegerVariable>(&self, variable: Var) -> i32 {
293        variable.upper_bound(&self.assignments)
294    }
295
296    /// Returns whether the given `variable` contains the provided `value`.
297    pub fn contains<Var: IntegerVariable>(&self, variable: Var, value: i32) -> bool {
298        variable.contains(&self.assignments, value)
299    }
300
301    /// If the given `variable` is fixed, then [`Some`] containing the assigned value is
302    /// returned. Otherwise, [`None`] is returned.
303    pub fn fixed_value<Var: IntegerVariable>(&self, variable: Var) -> Option<i32> {
304        (self.lower_bound(variable.clone()) == self.upper_bound(variable.clone()))
305            .then(|| self.lower_bound(variable))
306    }
307
308    /// Returns the truth value of the provided [`Predicate`].
309    ///
310    /// If the [`Predicate`] is assigned in the current [`State`] then [`Some`] containing whether
311    /// the [`Predicate`] is satisfied or falsified is returned. Otherwise, [`None`] is returned.
312    pub fn truth_value(&self, predicate: Predicate) -> Option<bool> {
313        self.assignments.evaluate_predicate(predicate)
314    }
315
316    /// If the provided [`Predicate`] is satisfied then it returns [`Some`] containing the
317    /// checkpoint at which the [`Predicate`] became satisfied. Otherwise, [`None`] is returned.
318    pub fn get_checkpoint_for_predicate(&self, predicate: Predicate) -> Option<usize> {
319        self.assignments.get_checkpoint_for_predicate(&predicate)
320    }
321
322    /// Returns the truth value of the provided [`Literal`].
323    ///
324    /// If the [`Literal`] is assigned in the current [`State`] then [`Some`] containing whether
325    /// the [`Literal`] is satisfied or falsified is returned. Otherwise, [`None`] is returned.
326    pub fn get_literal_value(&self, literal: Literal) -> Option<bool> {
327        self.truth_value(literal.get_true_predicate())
328    }
329
330    /// Returns the number of created checkpoints.
331    pub fn get_checkpoint(&self) -> usize {
332        self.assignments.get_checkpoint()
333    }
334}
335
336/// Operations for retrieving information about trail
337impl State {
338    /// Returns the length of the trail.
339    pub(crate) fn trail_len(&self) -> usize {
340        self.assignments.num_trail_entries()
341    }
342
343    /// Returns the [`Predicate`] at the provided `trail_index`.
344    pub(crate) fn trail_entry(&self, trail_index: usize) -> ConstraintProgrammingTrailEntry {
345        self.assignments.get_trail_entry(trail_index)
346    }
347
348    /// Returns whether the provided [`Predicate`] is explicitly on the trail.
349    ///
350    /// For example, if we post the [`Predicate`] [x >= v], then the predicate [x >= v - 1] is
351    /// not explicity on the trail.
352    pub fn is_on_trail(&self, predicate: Predicate) -> bool {
353        let trail_position = self.trail_position(predicate);
354
355        trail_position.is_some_and(|trail_position| {
356            self.assignments.trail[trail_position].predicate == predicate
357        })
358    }
359
360    /// Returns whether the trail position of the provided [`Predicate`].
361    pub fn trail_position(&self, predicate: Predicate) -> Option<usize> {
362        self.assignments.get_trail_position(&predicate)
363    }
364}
365
/// Operations for adding constraints.
impl State {
    /// Enqueues the propagator with [`PropagatorHandle`] `handle` for propagation.
    ///
    /// The propagator is pushed onto the propagator queue with the priority it reports itself;
    /// no propagation is performed here.
    #[deprecated]
    pub(crate) fn enqueue_propagator<P: Propagator>(&mut self, handle: PropagatorHandle<P>) {
        let priority = self.propagators[handle.propagator_id()].priority();
        self.propagator_queue
            .enqueue_propagator(handle.propagator_id(), priority);
    }

    /// Add a new propagator to the [`State`]. The constructor for that propagator should
    /// subscribe to the appropriate domain events so that the propagator is called when
    /// necessary.
    ///
    /// While the propagator is added to the queue for propagation, this function does _not_
    /// trigger a round of propagation. An explicit call to [`State::propagate_to_fixed_point`] is
    /// necessary to run the new propagator for the first time.
    ///
    /// # Panics
    ///
    /// If the priority of the constructed propagator exceeds 3 (checked with
    /// `pumpkin_assert_simple!`).
    pub fn add_propagator<Constructor>(
        &mut self,
        constructor: Constructor,
    ) -> PropagatorHandle<Constructor::PropagatorImpl>
    where
        Constructor: PropagatorConstructor,
        Constructor::PropagatorImpl: 'static,
    {
        // With the `check-propagations` feature, the constructor first registers its inference
        // checkers with this state.
        #[cfg(feature = "check-propagations")]
        constructor.add_inference_checkers(InferenceCheckers::new(self));

        // Obtain the handle the new propagator will be stored under, so that the constructor
        // context knows the `PropagatorId` before the propagator exists.
        // NOTE(review): this appears to rely on `new_propagator()` not claiming a slot until
        // `populate` is called; the equality assertion below guards that assumption.
        let original_handle: PropagatorHandle<Constructor::PropagatorImpl> =
            self.propagators.new_propagator().key();

        let constructor_context =
            PropagatorConstructorContext::new(original_handle.propagator_id(), self);
        let propagator = constructor.create(constructor_context);

        // The propagator queue only supports priorities 0..=3.
        pumpkin_assert_simple!(
            propagator.priority() as u8 <= 3,
            "The propagator priority exceeds 3.
             Currently we only support values up to 3,
             but this can easily be changed if there is a good reason."
        );

        // Actually store the propagator; it must end up under the id handed to the constructor.
        let slot = self.propagators.new_propagator();
        let handle = slot.populate(propagator);

        pumpkin_assert_eq_simple!(handle.propagator_id(), original_handle.propagator_id());

        // Schedule the new propagator so that the next fixed-point propagation runs it.
        #[allow(deprecated, reason = "Will be refactored")]
        self.enqueue_propagator(handle);

        handle
    }

    /// Add an inference checker to the state.
    ///
    /// The inference checker will be used to check propagations performed during
    /// [`Self::propagate_to_fixed_point`], if the `check-propagations` feature is enabled.
    ///
    /// Multiple inference checkers may be added for the same inference code. In that case, if
    /// any checker accepts the inference, the inference is accepted.
    pub fn add_inference_checker(
        &mut self,
        inference_code: InferenceCode,
        checker: Box<dyn InferenceChecker<Predicate>>,
    ) {
        // Checkers are grouped per inference code; an empty list is created on first use.
        let checkers = self.checkers.entry(inference_code).or_default();
        checkers.push(BoxedChecker::from(checker));
    }
}
435
436/// Operations for retrieving propagators.
437impl State {
438    /// Get a reference to the propagator identified by the given handle.
439    ///
440    /// For an exclusive reference, use [`State::get_propagator_mut`].
441    pub fn get_propagator<P: Propagator>(&self, handle: PropagatorHandle<P>) -> Option<&P> {
442        self.propagators.get_propagator(handle)
443    }
444
445    /// Get an exclusive reference to the propagator identified by the given handle.
446    pub fn get_propagator_mut<P: Propagator>(
447        &mut self,
448        handle: PropagatorHandle<P>,
449    ) -> Option<&mut P> {
450        self.propagators.get_propagator_mut(handle)
451    }
452
453    /// Get an exclusive reference to the propagator identified by the given handle and a context
454    /// which can be used for propagation.
455    pub(crate) fn get_propagator_mut_with_context<P: Propagator>(
456        &mut self,
457        handle: PropagatorHandle<P>,
458    ) -> (Option<&mut P>, PropagationContext<'_>) {
459        (
460            self.propagators.get_propagator_mut(handle),
461            PropagationContext::new(
462                &mut self.trailed_values,
463                &mut self.assignments,
464                &mut self.reason_store,
465                &mut self.notification_engine,
466                handle.propagator_id(),
467            ),
468        )
469    }
470}
471
472/// Operations for modifying the state.
473impl State {
474    /// Apply a [`Predicate`] to the [`State`].
475    ///
476    /// Returns `true` if a change to a domain occured, and `false` if the given [`Predicate`] was
477    /// already true.
478    ///
479    /// If a domain becomes empty due to this operation, an [`EmptyDomain`] error is returned.
480    ///
481    /// This method does _not_ perform any propagation. For that, an explicit call to
482    /// [`State::propagate_to_fixed_point`] is required. This allows the
483    /// posting of multiple predicates before the entire propagation engine is invoked.
484    ///
485    /// A call to [`State::restore_to`] that goes past the checkpoint at which a [`Predicate`]
486    /// was posted will undo the effect of that [`Predicate`]. See the documentation of
487    /// [`State::new_checkpoint`] and
488    /// [`State::restore_to`] for more information.
489    pub fn post(&mut self, predicate: Predicate) -> Result<bool, EmptyDomain> {
490        self.assignments
491            .post_predicate(predicate, None, &mut self.notification_engine)
492    }
493
494    #[cfg(test)]
495    fn post_with_reason(
496        &mut self,
497        predicate: Predicate,
498        reason: impl Into<Reason>,
499        inference_code: InferenceCode,
500        propagator_id: PropagatorId,
501    ) -> Result<(), EmptyDomainConflict> {
502        let slot = self.reason_store.new_slot();
503
504        let modification_result = self.assignments.post_predicate(
505            predicate,
506            Some((slot.reason_ref(), inference_code.clone())),
507            &mut self.notification_engine,
508        );
509
510        match modification_result {
511            Ok(false) => Ok(()),
512            Ok(true) => {
513                use crate::propagation::build_reason;
514
515                let _ = slot.populate(propagator_id, build_reason(reason, None));
516                Ok(())
517            }
518            Err(EmptyDomain) => {
519                use crate::propagation::build_reason;
520
521                let _ = slot.populate(propagator_id, build_reason(reason, None));
522                let (trigger_predicate, trigger_reason, trigger_inference_code) =
523                    self.assignments.remove_last_trail_element();
524
525                Err(EmptyDomainConflict {
526                    trigger_predicate,
527                    trigger_reason,
528                    trigger_inference_code,
529                })
530            }
531        }
532    }
533
534    /// Create a checkpoint of the current [`State`], that can be returned to with
535    /// [`State::restore_to`].
536    ///
537    /// The current checkpoint can be retrieved using the method [`State::get_checkpoint`].
538    ///
539    /// If the state is not at fixed-point, then this method will panic.
540    ///
541    /// # Example
542    /// ```
543    /// use pumpkin_core::predicate;
544    /// use pumpkin_core::state::State;
545    ///
546    /// let mut state = State::default();
547    /// let variable = state.new_interval_variable(1, 10, Some("x1".into()));
548    ///
549    /// assert_eq!(state.get_checkpoint(), 0);
550    ///
551    /// state.new_checkpoint();
552    ///
553    /// assert_eq!(state.get_checkpoint(), 1);
554    ///
555    /// state
556    ///     .post(predicate![variable <= 5])
557    ///     .expect("The lower bound is 1 so no conflict");
558    /// assert_eq!(state.upper_bound(variable), 5);
559    ///
560    /// state.restore_to(0);
561    ///
562    /// assert_eq!(state.get_checkpoint(), 0);
563    /// assert_eq!(state.upper_bound(variable), 10);
564    /// ```
565    pub fn new_checkpoint(&mut self) {
566        pumpkin_assert_simple!(
567            self.propagator_queue.is_empty(),
568            "Can only create a new checkpoint when all propagation has occurred"
569        );
570        self.assignments.new_checkpoint();
571        self.notification_engine.new_checkpoint();
572        self.trailed_values.new_checkpoint();
573        self.reason_store.new_checkpoint();
574    }
575
576    /// Restore to the given checkpoint and return the [`DomainId`]s which were fixed before
577    /// restoring, with their assigned values.
578    ///
579    /// If the provided checkpoint is equal to the current checkpoint, this is a no-op. If
580    /// the provided checkpoint is larger than the current checkpoint, this method will
581    /// panic.
582    ///
583    /// See [`State::new_checkpoint`] for an example.
584    pub fn restore_to(&mut self, checkpoint: usize) -> Vec<(DomainId, i32)> {
585        pumpkin_assert_simple!(checkpoint <= self.get_checkpoint());
586
587        self.statistics.sum_of_backjumps += (self.get_checkpoint() - 1 - checkpoint) as u64;
588        if self.get_checkpoint() - checkpoint > 1 {
589            self.statistics.num_backjumps += 1;
590        }
591
592        if checkpoint == self.get_checkpoint() {
593            return vec![];
594        }
595
596        let unfixed_after_backtracking = self
597            .assignments
598            .synchronise(checkpoint, &mut self.notification_engine);
599        self.trailed_values.synchronise(checkpoint);
600        self.reason_store.synchronise(checkpoint);
601
602        self.propagator_queue.clear();
603        // For now all propagators are called to synchronise, in the future this will be improved in
604        // two ways:
605        //      + allow incremental synchronisation
606        //      + only call the subset of propagators that were notified since last backtrack
607        for propagator in self.propagators.iter_propagators_mut() {
608            let mut context = NotificationContext::new(&mut self.trailed_values, &self.assignments);
609
610            propagator.synchronise(context.reborrow());
611        }
612
613        let _ = self.notification_engine.process_backtrack_events(
614            &mut self.assignments,
615            &mut self.trailed_values,
616            &mut self.propagators,
617        );
618        self.notification_engine.clear_event_drain();
619
620        self.notification_engine
621            .update_last_notified_index(&mut self.assignments);
622        // Should be done after the assignments and trailed values have been synchronised
623        self.notification_engine.synchronise(
624            checkpoint,
625            &self.assignments,
626            &mut self.trailed_values,
627        );
628
629        unfixed_after_backtracking
630    }
631
632    /// Performs a single call to [`Propagator::propagate`] for the propagator with the provided
633    /// [`PropagatorId`].
634    ///
635    /// Other propagators could be enqueued as a result of the changes made by the propagated
636    /// propagator but a call to [`State::propagate_to_fixed_point`] is
637    /// required for further propagation to occur.
638    ///
639    /// It could be that the current [`State`] implies a conflict by propagation. In that case, an
640    /// [`Err`] with [`Conflict`] is returned.
641    ///
642    /// Once the [`State`] is conflicting, then the only operation that is defined is
643    /// [`State::restore_to`]. All other operations and queries on the state are undetermined.
644    fn propagate(&mut self, propagator_id: PropagatorId) -> Result<(), Conflict> {
645        self.statistics.num_propagators_called += 1;
646
647        let num_trail_entries_before = self.assignments.num_trail_entries();
648
649        let propagation_status = {
650            let propagator = &mut self.propagators[propagator_id];
651            let context = PropagationContext::new(
652                &mut self.trailed_values,
653                &mut self.assignments,
654                &mut self.reason_store,
655                &mut self.notification_engine,
656                propagator_id,
657            );
658            propagator.propagate(context)
659        };
660
661        #[cfg(feature = "check-propagations")]
662        self.check_propagations(num_trail_entries_before);
663
664        match propagation_status {
665            Ok(_) => {
666                // Notify other propagators of the propagations and continue.
667                self.notification_engine
668                    .notify_propagators_about_domain_events(
669                        &mut self.assignments,
670                        &mut self.trailed_values,
671                        &mut self.propagators,
672                        &mut self.propagator_queue,
673                    );
674                pumpkin_assert_extreme!(
675                    DebugHelper::debug_check_propagations(
676                        num_trail_entries_before,
677                        propagator_id,
678                        &self.trailed_values,
679                        &self.assignments,
680                        &mut self.reason_store,
681                        &mut self.propagators,
682                        &self.notification_engine
683                    ),
684                    "Checking the propagations performed by the propagator led to inconsistencies!"
685                );
686            }
687            Err(conflict) => {
688                #[cfg(feature = "check-propagations")]
689                self.check_conflict(&conflict);
690
691                self.statistics.num_conflicts += 1;
692                if let Conflict::Propagator(inner) = &conflict {
693                    pumpkin_assert_advanced!(DebugHelper::debug_reported_failure(
694                        &self.trailed_values,
695                        &self.assignments,
696                        &inner.conjunction,
697                        &self.propagators[propagator_id],
698                        propagator_id,
699                        &self.notification_engine
700                    ));
701                }
702
703                return Err(conflict);
704            }
705        }
706        Ok(())
707    }
708
709    /// Check the inference that triggered the given conflict.
710    ///
711    /// Does nothing when the conflict is an empty domain.
712    ///
713    /// Panics when the inference checker rejects the conflict.
714    #[cfg(feature = "check-propagations")]
715    fn check_conflict(&mut self, conflict: &Conflict) {
716        if let Conflict::Propagator(propagator_conflict) = conflict {
717            self.run_checker(
718                propagator_conflict.conjunction.clone(),
719                None,
720                &propagator_conflict.inference_code,
721            );
722        }
723    }
724
725    /// For every item on the trail starting at index `first_propagation_index`, run the
726    /// inference checker for it.
727    ///
728    /// This method should be called after every propagator invocation, so all elements on the
729    /// trail starting at `first_propagation_index` should be propagations. Otherwise this function
730    /// will panic.
731    ///
732    /// If the checker rejects the inference, this method panics.
733    #[cfg(feature = "check-propagations")]
734    pub(crate) fn check_propagations(&mut self, first_propagation_index: usize) {
735        let mut reason_buffer = vec![];
736
737        for trail_index in first_propagation_index..self.assignments.num_trail_entries() {
738            let entry = self.assignments.get_trail_entry(trail_index);
739
740            let (reason_ref, inference_code) = entry
741                .reason
742                .expect("propagations should only be checked after propagations");
743
744            reason_buffer.clear();
745            let reason_exists = self.reason_store.get_or_compute(
746                reason_ref,
747                ExplanationContext::without_working_nogood(
748                    &self.assignments,
749                    trail_index,
750                    &mut self.notification_engine,
751                ),
752                &mut self.propagators,
753                &mut reason_buffer,
754            );
755            assert!(reason_exists, "all propagations have reasons");
756
757            self.run_checker(
758                std::mem::take(&mut reason_buffer),
759                Some(entry.predicate),
760                &inference_code,
761            );
762        }
763    }
764
765    /// Performs fixed-point propagation using the propagators defined in the [`State`].
766    ///
767    /// The posted [`Predicate`]s (using [`State::post`]) and added propagators (using
768    /// [`State::add_propagator`]) cause propagators to be enqueued when the events that
769    /// they have subscribed to are triggered. As propagation causes more changes to be made,
770    /// more propagators are enqueued. This continues until applying all (enqueued)
771    /// propagators leads to no more domain changes.
772    ///
773    /// It could be that the current [`State`] implies a conflict by propagation. In that case, an
774    /// error with [`Conflict`] is returned.
775    ///
776    /// Once the [`State`] is conflicting, then the only operation that is defined is
777    /// [`State::restore_to`]. All other operations and queries on the state are unspecified.
778    pub fn propagate_to_fixed_point(&mut self) -> Result<(), Conflict> {
779        // The initial domain events are due to the decision predicate.
780        self.notification_engine
781            .notify_propagators_about_domain_events(
782                &mut self.assignments,
783                &mut self.trailed_values,
784                &mut self.propagators,
785                &mut self.propagator_queue,
786            );
787
788        // Keep propagating until there are unprocessed propagators, or a conflict is detected.
789        while let Some(propagator_id) = self.propagator_queue.pop() {
790            self.propagate(propagator_id)?;
791        }
792
793        // Only check fixed point propagation if there was no reported conflict,
794        // since otherwise the state may be inconsistent.
795        pumpkin_assert_extreme!(DebugHelper::debug_fixed_point_propagation(
796            &self.trailed_values,
797            &self.assignments,
798            &self.propagators,
799            &self.notification_engine
800        ));
801
802        Ok(())
803    }
804}
805
#[cfg(feature = "check-propagations")]
impl State {
    /// Runs the checkers registered for `inference_code` on the inference
    /// `premises -> consequent`; a `consequent` of `None` denotes a conflict.
    ///
    /// The inference is accepted when at least one registered checker accepts it.
    ///
    /// # Panics
    ///
    /// Panics if no checker is registered for `inference_code`, if the premises and consequent
    /// are inconsistent over some domain, or if none of the registered checkers accepts the
    /// inference.
    fn run_checker(
        &self,
        premises: impl IntoIterator<Item = Predicate>,
        consequent: Option<Predicate>,
        inference_code: &InferenceCode,
    ) {
        let premises: Vec<_> = premises.into_iter().collect();

        let checkers = self
            .checkers
            .get(inference_code)
            .map(|vec| vec.as_slice())
            .unwrap_or(&[]);

        assert!(
            !checkers.is_empty(),
            "missing checker for inference code {inference_code:?}"
        );

        let any_checker_accepts_inference = checkers.iter().any(|checker| {
            // `check` takes the variable state by value, so a fresh one is built for every
            // checker that is tried.
            let variable_state = VariableState::prepare_for_conflict_check(
                premises.clone(),
                consequent,
            )
            .unwrap_or_else(|domain| {
                panic!(
                    "inconsistent atomics over domain {domain:?} in inference by {inference_code:?}"
                )
            });

            checker.check(variable_state, &premises, consequent.as_ref())
        });

        // `premises` is already a `Vec`, so it can be formatted directly.
        assert!(
            any_checker_accepts_inference,
            "checker for inference code {:?} fails on inference {:?} -> {:?}",
            inference_code,
            premises,
            consequent,
        );
    }
}
852
853impl State {
854    /// This is a temporary accessor to help refactoring.
855    pub(crate) fn get_solution_reference(&self) -> SolutionReference<'_> {
856        SolutionReference::new(&self.assignments)
857    }
858
859    /// Returns a mapping of [`DomainId`] to variable name.
860    pub(crate) fn variable_names(&self) -> &VariableNames {
861        &self.variable_names
862    }
863
864    pub(crate) fn get_propagation_reason_trail_entry(
865        &mut self,
866        trail_position: usize,
867        reason_buffer: &mut (impl Extend<Predicate> + AsRef<[Predicate]>),
868    ) {
869        let entry = self.trail_entry(trail_position);
870        let (reason_ref, _) = entry
871            .reason
872            .expect("Added by a propagator and must therefore have a reason");
873        let _ = self.reason_store.get_or_compute(
874            reason_ref,
875            ExplanationContext::without_working_nogood(
876                &self.assignments,
877                trail_position,
878                &mut self.notification_engine,
879            ),
880            &mut self.propagators,
881            reason_buffer,
882        );
883    }
884    /// Get the reason for a predicate being true and store it in `reason_buffer`; additionally, if
885    /// the provided [`Predicate`] is explicitly on the trail, this method will return the
886    /// corresponding trail index.
887    ///
888    /// The provided `current_nogood` can be used by the propagator to provide a different reason;
889    /// use [`CurrentNogood::empty`] otherwise.
890    ///
891    /// All the predicates in the returned slice will evaluate to `true`.
892    ///
893    /// If the provided predicate is not true, then this method will panic.
894    #[allow(unused, reason = "Will be part of public API")]
895    pub fn get_propagation_reason(
896        &mut self,
897        predicate: Predicate,
898        reason_buffer: &mut (impl Extend<Predicate> + AsRef<[Predicate]>),
899        current_nogood: CurrentNogood<'_>,
900    ) -> Option<usize> {
901        // TODO: this function could be put into the reason store
902
903        // Note that this function can only be called with propagations, and never decision
904        // predicates. Furthermore only predicate from the current checkpoint will be
905        // considered. This is due to how the 1uip conflict analysis works: it scans the
906        // predicates in reverse order of assignment, and stops as soon as there is only one
907        // predicate from the current checkpoint in the learned nogood.
908
909        // This means that the procedure would never ask for the reason of the decision predicate
910        // from the current checkpoint, because that would mean that all other predicates from
911        // the current checkpoint have been removed from the nogood, and the decision
912        // predicate is the only one left, but in that case, the 1uip would terminate since
913        // there would be only one predicate from the current checkpoint. For this
914        // reason, it is safe to assume that in the following, that any input predicate is
915        // indeed a propagated predicate.
916        if self.assignments.is_initial_bound(predicate) {
917            return None;
918        }
919
920        let trail_position = self
921            .assignments
922            .get_trail_position(&predicate)
923            .unwrap_or_else(|| panic!("The predicate {predicate:?} must be true during conflict analysis. Bounds were {},{}", self.lower_bound(predicate.get_domain()), self.upper_bound(predicate.get_domain())));
924
925        let trail_entry = self.assignments.get_trail_entry(trail_position);
926
927        // We distinguish between three cases:
928        // 1) The predicate is explicitly present on the trail.
929        if trail_entry.predicate == predicate {
930            let (reason_ref, inference_code) = trail_entry
931                .reason
932                .expect("Cannot be a null reason for propagation.");
933
934            let explanation_context = ExplanationContext::new(
935                &self.assignments,
936                current_nogood,
937                trail_position,
938                &mut self.notification_engine,
939            );
940
941            let reason_exists = self.reason_store.get_or_compute(
942                reason_ref,
943                explanation_context,
944                &mut self.propagators,
945                reason_buffer,
946            );
947
948            assert!(reason_exists, "reason reference should not be stale");
949
950            Some(trail_position)
951        }
952        // 2) The predicate is true due to a propagation, and not explicitly on the trail.
953        // It is necessary to further analyse what was the reason for setting the predicate true.
954        else {
955            // The reason for propagation depends on:
956            // 1) The predicate on the trail at the moment the input predicate became true, and
957            // 2) The input predicate.
958            match (
959                trail_entry.predicate.get_predicate_type(),
960                predicate.get_predicate_type(),
961            ) {
962                (PredicateType::LowerBound, PredicateType::LowerBound) => {
963                    let trail_lower_bound = trail_entry.predicate.get_right_hand_side();
964                    let domain_id = predicate.get_domain();
965                    let input_lower_bound = predicate.get_right_hand_side();
966                    // Both the input predicate and the trail predicate are lower bound
967                    // literals. Two cases to consider:
968                    // 1) The trail predicate has a greater right-hand side, meaning
969                    //  the reason for the input predicate is true is because a stronger
970                    //  right-hand side predicate was posted. We can reuse the same
971                    //  reason as for the trail bound.
972                    //  todo: could consider lifting here, since the trail bound
973                    //  might be too strong.
974                    if trail_lower_bound > input_lower_bound {
975                        reason_buffer.extend(std::iter::once(trail_entry.predicate));
976                    }
977                    // Otherwise, the input bound is strictly greater than the trailed
978                    // bound. This means the reason is due to holes in the domain.
979                    else {
980                        // Note that the bounds cannot be equal.
981                        // If the bound were equal, the predicate would be explicitly on the
982                        // trail, so we would have detected this case earlier.
983                        pumpkin_assert_simple!(trail_lower_bound < input_lower_bound);
984
985                        // The reason for the propagation of the input predicate [x >= a] is
986                        // because [x >= a-1] & [x != a]. Conflict analysis will then
987                        // recursively decompose these further.
988
989                        // Note that we do not need to worry about decreasing the lower
990                        // bounds so much so that it reaches its root lower bound, for which
991                        // there is no reason since it is given as input to the problem.
992                        // We cannot reach the original lower bound since in the 1uip, we
993                        // only look for reasons for predicates from the current decision
994                        // level, and we never look for reasons at the root level.
995
996                        let one_less_bound_predicate =
997                            predicate!(domain_id >= input_lower_bound - 1);
998
999                        let not_equals_predicate = predicate!(domain_id != input_lower_bound - 1);
1000                        reason_buffer.extend(std::iter::once(one_less_bound_predicate));
1001                        reason_buffer.extend(std::iter::once(not_equals_predicate));
1002                    }
1003                }
1004                (PredicateType::LowerBound, PredicateType::NotEqual) => {
1005                    let trail_lower_bound = trail_entry.predicate.get_right_hand_side();
1006                    let not_equal_constant = predicate.get_right_hand_side();
1007                    // The trail entry is a lower bound literal,
1008                    // and the input predicate is a not equals.
1009                    // Only one case to consider:
1010                    // The trail lower bound is greater than the not_equals_constant,
1011                    // so it safe to take the reason from the trail.
1012                    // todo: lifting could be used here
1013                    pumpkin_assert_simple!(trail_lower_bound > not_equal_constant);
1014                    reason_buffer.extend(std::iter::once(trail_entry.predicate));
1015                }
1016                (PredicateType::LowerBound, PredicateType::Equal) => {
1017                    let domain_id = predicate.get_domain();
1018                    let equality_constant = predicate.get_right_hand_side();
1019                    // The input predicate is an equality predicate, and the trail predicate
1020                    // is a lower bound predicate. This means that the time of posting the
1021                    // trail predicate is when the input predicate became true.
1022
1023                    // Note that the input equality constant does _not_ necessarily equal
1024                    // the trail lower bound. This would be the
1025                    // case when the the trail lower bound is lower than the input equality
1026                    // constant, but due to holes in the domain, the lower bound got raised
1027                    // to just the value of the equality constant.
1028                    // For example, {1, 2, 3, 10}, then posting [x >= 5] will raise the
1029                    // lower bound to x >= 10.
1030
1031                    let predicate_lb = predicate!(domain_id >= equality_constant);
1032                    let predicate_ub = predicate!(domain_id <= equality_constant);
1033                    reason_buffer.extend(std::iter::once(predicate_lb));
1034                    reason_buffer.extend(std::iter::once(predicate_ub));
1035                }
1036                (PredicateType::UpperBound, PredicateType::UpperBound) => {
1037                    let trail_upper_bound = trail_entry.predicate.get_right_hand_side();
1038                    let domain_id = predicate.get_domain();
1039                    let input_upper_bound = predicate.get_right_hand_side();
1040                    // Both the input and trail predicates are upper bound predicates.
1041                    // There are two scenarios to consider:
1042                    // 1) The input upper bound is greater than the trail upper bound, meaning that
1043                    //    the reason for the input predicate is the propagation of a stronger upper
1044                    //    bound. We can safely use the reason for of the trail predicate as the
1045                    //    reason for the input predicate.
1046                    // todo: lifting could be applied here.
1047                    if trail_upper_bound < input_upper_bound {
1048                        reason_buffer.extend(std::iter::once(trail_entry.predicate));
1049                    } else {
1050                        // I think it cannot be that the bounds are equal, since otherwise we
1051                        // would have found the predicate explicitly on the trail.
1052                        pumpkin_assert_simple!(trail_upper_bound > input_upper_bound);
1053
1054                        // The input upper bound is greater than the trail predicate, meaning
1055                        // that holes in the domain also played a rule in lowering the upper
1056                        // bound.
1057
1058                        // The reason of the input predicate [x <= a] is computed recursively as
1059                        // the reason for [x <= a + 1] & [x != a + 1].
1060
1061                        let new_ub_predicate = predicate!(domain_id <= input_upper_bound + 1);
1062                        let not_equal_predicate = predicate!(domain_id != input_upper_bound + 1);
1063                        reason_buffer.extend(std::iter::once(new_ub_predicate));
1064                        reason_buffer.extend(std::iter::once(not_equal_predicate));
1065                    }
1066                }
1067                (PredicateType::UpperBound, PredicateType::NotEqual) => {
1068                    let trail_upper_bound = trail_entry.predicate.get_right_hand_side();
1069                    let not_equal_constant = predicate.get_right_hand_side();
1070                    // The input predicate is a not equal predicate, and the trail predicate is
1071                    // an upper bound predicate. This is only possible when the upper bound was
1072                    // pushed below the not equals value. Otherwise the hole would have been
1073                    // explicitly placed on the trail and we would have found it earlier.
1074                    pumpkin_assert_simple!(not_equal_constant > trail_upper_bound);
1075
1076                    // The bound was set past the not equals, so we can safely returns the trail
1077                    // reason. todo: can do lifting here.
1078                    reason_buffer.extend(std::iter::once(trail_entry.predicate));
1079                }
1080                (PredicateType::UpperBound, PredicateType::Equal) => {
1081                    let domain_id = predicate.get_domain();
1082                    let equality_constant = predicate.get_right_hand_side();
1083                    // The input predicate is an equality predicate, and the trail predicate
1084                    // is an upper bound predicate. This means that the time of posting the
1085                    // trail predicate is when the input predicate became true.
1086
1087                    // Note that the input equality constant does _not_ necessarily equal
1088                    // the trail upper bound. This would be the
1089                    // case when the the trail upper bound is greater than the input equality
1090                    // constant, but due to holes in the domain, the upper bound got lowered
1091                    // to just the value of the equality constant.
1092                    // For example, x = {1, 2, 3, 8, 15}, setting [x <= 12] would lower the
1093                    // upper bound to x <= 8.
1094
1095                    // Note that it could be that one of the two predicates are decision
1096                    // predicates, so we need to use the substitute functions.
1097
1098                    let predicate_lb = predicate!(domain_id >= equality_constant);
1099                    let predicate_ub = predicate!(domain_id <= equality_constant);
1100                    reason_buffer.extend(std::iter::once(predicate_lb));
1101                    reason_buffer.extend(std::iter::once(predicate_ub));
1102                }
1103                (PredicateType::NotEqual, PredicateType::LowerBound) => {
1104                    let not_equal_constant = trail_entry.predicate.get_right_hand_side();
1105                    let domain_id = predicate.get_domain();
1106                    let input_lower_bound = predicate.get_right_hand_side();
1107                    // The trail predicate is not equals, but the input predicate is a lower
1108                    // bound predicate. This means that creating the hole in the domain resulted
1109                    // in raising the lower bound.
1110
1111                    // I think this holds. The not_equals_constant cannot be greater, since that
1112                    // would not impact the lower bound. It can also not be the same, since
1113                    // creating a hole cannot result in the lower bound being raised to the
1114                    // hole, there must be some other reason for that to happen, which we would
1115                    // find earlier.
1116                    pumpkin_assert_simple!(input_lower_bound > not_equal_constant);
1117
1118                    // The reason for the input predicate [x >= a] is computed recursively as
1119                    // the reason for [x >= a - 1] & [x != a - 1].
1120                    let new_lb_predicate = predicate!(domain_id >= input_lower_bound - 1);
1121                    let new_not_equals_predicate = predicate!(domain_id != input_lower_bound - 1);
1122
1123                    reason_buffer.extend(std::iter::once(new_lb_predicate));
1124                    reason_buffer.extend(std::iter::once(new_not_equals_predicate));
1125                }
1126                (PredicateType::NotEqual, PredicateType::UpperBound) => {
1127                    let not_equal_constant = trail_entry.predicate.get_right_hand_side();
1128                    let domain_id = predicate.get_domain();
1129                    let input_upper_bound = predicate.get_right_hand_side();
1130                    // The trail predicate is not equals, but the input predicate is an upper
1131                    // bound predicate. This means that creating the hole in the domain resulted
1132                    // in lower the upper bound.
1133
1134                    // I think this holds. The not_equals_constant cannot be smaller, since that
1135                    // would not impact the upper bound. It can also not be the same, since
1136                    // creating a hole cannot result in the upper bound being lower to the
1137                    // hole, there must be some other reason for that to happen, which we would
1138                    // find earlier.
1139                    pumpkin_assert_simple!(input_upper_bound < not_equal_constant);
1140
1141                    // The reason for the input predicate [x <= a] is computed recursively as
1142                    // the reason for [x <= a + 1] & [x != a + 1].
1143                    let new_ub_predicate = predicate!(domain_id <= input_upper_bound + 1);
1144                    let new_not_equals_predicate = predicate!(domain_id != input_upper_bound + 1);
1145
1146                    reason_buffer.extend(std::iter::once(new_ub_predicate));
1147                    reason_buffer.extend(std::iter::once(new_not_equals_predicate));
1148                }
1149                (PredicateType::NotEqual, PredicateType::Equal) => {
1150                    let domain_id = predicate.get_domain();
1151                    let equality_constant = predicate.get_right_hand_side();
1152                    // The trail predicate is not equals, but the input predicate is
1153                    // equals. The only time this could is when the not equals forces the
1154                    // lower/upper bounds to meet. So we simply look for the reasons for those
1155                    // bounds recursively.
1156
1157                    // Note that it could be that one of the two predicates are decision
1158                    // predicates, so we need to use the substitute functions.
1159
1160                    let predicate_lb = predicate!(domain_id >= equality_constant);
1161                    let predicate_ub = predicate!(domain_id <= equality_constant);
1162
1163                    reason_buffer.extend(std::iter::once(predicate_lb));
1164                    reason_buffer.extend(std::iter::once(predicate_ub));
1165                }
1166                (
1167                    PredicateType::Equal,
1168                    PredicateType::LowerBound | PredicateType::UpperBound | PredicateType::NotEqual,
1169                ) => {
1170                    // The trail predicate is equality, but the input predicate is either a
1171                    // lower-bound, upper-bound, or not equals.
1172                    //
1173                    // TODO: could consider lifting here
1174                    reason_buffer.extend(std::iter::once(trail_entry.predicate))
1175                }
1176                _ => unreachable!(
1177                    "Unreachable combination of {} and {}",
1178                    trail_entry.predicate, predicate
1179                ),
1180            };
1181            None
1182        }
1183    }
1184}
1185
1186impl State {
1187    pub fn get_domains(&mut self) -> Domains<'_> {
1188        Domains::new(&self.assignments, &mut self.trailed_values)
1189    }
1190
1191    pub fn get_propagation_context(&mut self) -> PropagationContext<'_> {
1192        PropagationContext::new(
1193            &mut self.trailed_values,
1194            &mut self.assignments,
1195            &mut self.reason_store,
1196            &mut self.notification_engine,
1197            PropagatorId(0),
1198        )
1199    }
1200}
1201
#[cfg(test)]
mod tests {
    use crate::conjunction;
    use crate::containers::StorageKey;
    use crate::declare_inference_label;
    use crate::predicate;
    use crate::proof::InferenceCode;
    use crate::state::CurrentNogood;
    use crate::state::PropagatorId;
    use crate::state::State;

    declare_inference_label!(TestLabel);

    // Posting [x >= 5] with the explicit reason [y >= 5] must make exactly that reason
    // retrievable afterwards, even though `y` was created before `x`.
    #[test]
    fn reason_correct_after_creation_variable() {
        let mut state = State::default();

        let y = state.new_interval_variable(0, 10, None);
        let x = state.new_interval_variable(0, 10, None);

        let constraint_tag = state.new_constraint_tag();
        let inference_code = InferenceCode::new(constraint_tag, TestLabel);
        let propagator = PropagatorId::create_from_index(0);

        assert_eq!(
            state.post_with_reason(
                predicate!(x >= 5),
                conjunction!([y >= 5]),
                inference_code,
                propagator,
            ),
            Ok(())
        );

        let mut reason = Vec::new();
        let _ =
            state.get_propagation_reason(predicate!(x >= 5), &mut reason, CurrentNogood::empty());

        assert_eq!(reason, vec![predicate!(y >= 5)]);
    }
}