// pumpkin_core/propagation/constructor.rs

1use std::ops::Deref;
2use std::ops::DerefMut;
3
4use pumpkin_checking::InferenceChecker;
5
6use super::Domains;
7use super::LocalId;
8use super::Propagator;
9use super::PropagatorId;
10use super::PropagatorVarId;
11#[cfg(doc)]
12use crate::Solver;
13use crate::basic_types::PredicateId;
14use crate::basic_types::RefOrOwned;
15use crate::engine::Assignments;
16use crate::engine::State;
17use crate::engine::TrailedValues;
18use crate::engine::notifications::Watchers;
19#[cfg(doc)]
20use crate::engine::variables::AffineView;
21#[cfg(doc)]
22use crate::engine::variables::DomainId;
23use crate::predicates::Predicate;
24use crate::proof::InferenceCode;
25#[cfg(doc)]
26use crate::propagation::DomainEvent;
27use crate::propagation::DomainEvents;
28use crate::propagators::reified_propagator::ReifiedChecker;
29use crate::variables::IntegerVariable;
30use crate::variables::Literal;
31
32/// A propagator constructor creates a fully initialized instance of a [`Propagator`].
33///
34/// The constructor is responsible for:
35/// 1) Indicating on which [`DomainEvent`]s the propagator should be enqueued (via the
36///    [`PropagatorConstructorContext`]).
37/// 2) Initialising the [`PropagatorConstructor::PropagatorImpl`] and its structures.
38pub trait PropagatorConstructor {
39    /// The propagator that is produced by this constructor.
40    type PropagatorImpl: Propagator + Clone;
41
42    /// Add inference checkers to the solver if applicable.
43    ///
44    /// If the `check-propagations` feature is turned on, then the inference checker will be used
45    /// to verify the propagations done by this propagator are correct.
46    ///
47    /// See [`InferenceChecker`] for more information.
48    fn add_inference_checkers(&self, _checkers: InferenceCheckers<'_>) {}
49
50    /// Create the propagator instance from `Self`.
51    fn create(self, context: PropagatorConstructorContext) -> Self::PropagatorImpl;
52}
53
/// Interface used to add [`InferenceChecker`]s to the [`State`].
///
/// When a reification literal has been set, every checker added afterwards is wrapped in a
/// `ReifiedChecker` for that literal before it is handed to the [`State`].
#[derive(Debug)]
pub struct InferenceCheckers<'state> {
    /// The state that receives the inference checkers.
    state: &'state mut State,
    /// The literal used to wrap added checkers in a `ReifiedChecker`. If `None`, checkers are
    /// forwarded to the state unchanged.
    reification_literal: Option<Literal>,
}
60
61impl<'state> InferenceCheckers<'state> {
62    #[cfg(feature = "check-propagations")]
63    pub(crate) fn new(state: &'state mut State) -> Self {
64        InferenceCheckers {
65            state,
66            reification_literal: None,
67        }
68    }
69}
70
71impl InferenceCheckers<'_> {
72    /// Forwards to [`State::add_inference_checker`].
73    pub fn add_inference_checker(
74        &mut self,
75        inference_code: InferenceCode,
76        checker: Box<dyn InferenceChecker<Predicate>>,
77    ) {
78        if let Some(reification_literal) = self.reification_literal {
79            let reification_checker = ReifiedChecker {
80                inner: checker.into(),
81                reification_literal,
82            };
83            self.state
84                .add_inference_checker(inference_code, Box::new(reification_checker));
85        } else {
86            self.state.add_inference_checker(inference_code, checker);
87        }
88    }
89
90    pub fn with_reification_literal(&mut self, literal: Literal) {
91        self.reification_literal = Some(literal)
92    }
93}
94
/// [`PropagatorConstructorContext`] is used when [`Propagator`]s are initialised after creation.
///
/// It represents a communication point between the [`Solver`] and the [`Propagator`].
/// Propagators use the [`PropagatorConstructorContext`] to register to domain changes
/// of variables and to retrieve the current bounds of variables.
#[derive(Debug)]
pub struct PropagatorConstructorContext<'a> {
    /// The state through which registrations are made and domain information is read.
    state: &'a mut State,
    /// The identifier under which registrations made through this context are recorded.
    pub(crate) propagator_id: PropagatorId,

    /// A [`LocalId`] that is guaranteed not to be used to register any variables yet. This is
    /// either a reference or an owned value, to support
    /// [`PropagatorConstructorContext::reborrow`].
    next_local_id: RefOrOwned<'a, LocalId>,

    /// Marker to indicate whether the constructor registered for at least one domain event or
    /// predicate becoming assigned. If not, the [`Drop`] implementation will cause a panic.
    did_register: RefOrOwned<'a, bool>,
}
114
115impl PropagatorConstructorContext<'_> {
116    pub(crate) fn new<'a>(
117        propagator_id: PropagatorId,
118        state: &'a mut State,
119    ) -> PropagatorConstructorContext<'a> {
120        PropagatorConstructorContext {
121            next_local_id: RefOrOwned::Owned(LocalId::from(0)),
122            propagator_id,
123            state,
124            did_register: RefOrOwned::Owned(false),
125        }
126    }
127
128    /// Indicate that the constructor is deliberately not registering the propagator to be enqueued
129    /// at any time.
130    ///
131    /// If this is called and later a registration happens, then the registration will still go
132    /// through. Calling this function only prevents the crash if no registration happens.
133    pub fn will_not_register_any_events(&mut self) {
134        *self.did_register = true;
135    }
136
137    /// Get domain information.
138    pub fn domains(&mut self) -> Domains<'_> {
139        Domains::new(&self.state.assignments, &mut self.state.trailed_values)
140    }
141
142    /// Subscribes the propagator to the given [`DomainEvents`].
143    ///
144    /// The domain events determine when [`Propagator::notify()`] will be called on the propagator.
145    /// The [`LocalId`] is internal information related to the propagator,
146    /// which is used when calling [`Propagator::notify()`] to identify the variable.
147    ///
148    /// Each variable *must* have a unique [`LocalId`]. Most often this would be its index of the
149    /// variable in the internal array of variables.
150    ///
151    /// Duplicate registrations are ignored.
152    pub fn register(
153        &mut self,
154        var: impl IntegerVariable,
155        domain_events: DomainEvents,
156        local_id: LocalId,
157    ) {
158        self.will_not_register_any_events();
159
160        let propagator_var = PropagatorVarId {
161            propagator: self.propagator_id,
162            variable: local_id,
163        };
164
165        self.update_next_local_id(local_id);
166
167        let mut watchers = Watchers::new(propagator_var, &mut self.state.notification_engine);
168        var.watch_all(&mut watchers, domain_events.events());
169    }
170
171    /// Register the propagator to be enqueued when the given [`Predicate`] becomes true.
172    /// Returns the [`PredicateId`] used by the solver to track the predicate.
173    pub fn register_predicate(&mut self, predicate: Predicate) -> PredicateId {
174        self.will_not_register_any_events();
175
176        self.state.notification_engine.watch_predicate(
177            predicate,
178            self.propagator_id,
179            &mut self.state.trailed_values,
180            &self.state.assignments,
181        )
182    }
183
184    /// Subscribes the propagator to the given [`DomainEvents`] when they are undone during
185    /// backtracking. This method is complementary to [`PropagatorConstructorContext::register`],
186    /// the [`LocalId`]s provided to both of these method should be the same for the same variable.
187    ///
188    /// The domain events determine when [`Propagator::notify_backtrack()`] will be called on the
189    /// propagator. The [`LocalId`] is internal information related to the propagator,
190    /// which is used when calling [`Propagator::notify_backtrack()`] to identify the variable.
191    ///
192    /// Each variable *must* have a unique [`LocalId`]. Most often this would be its index of the
193    /// variable in the internal array of variables.
194    ///
195    /// Note that the [`LocalId`] is used to differentiate between [`DomainId`]s and
196    /// [`AffineView`]s.
197    pub fn register_backtrack<Var: IntegerVariable>(
198        &mut self,
199        var: Var,
200        domain_events: DomainEvents,
201        local_id: LocalId,
202    ) {
203        let propagator_var = PropagatorVarId {
204            propagator: self.propagator_id,
205            variable: local_id,
206        };
207
208        self.update_next_local_id(local_id);
209
210        let mut watchers = Watchers::new(propagator_var, &mut self.state.notification_engine);
211        var.watch_all_backtrack(&mut watchers, domain_events.events());
212    }
213
214    /// Get a new [`LocalId`] which is guaranteed to be unused.
215    pub(crate) fn get_next_local_id(&self) -> LocalId {
216        *self.next_local_id.deref()
217    }
218
219    /// Reborrow the current context to a new value with a shorter lifetime. Should be used when
220    /// passing `Self` to another function that takes ownership, but the value is still needed
221    /// afterwards.
222    pub fn reborrow(&mut self) -> PropagatorConstructorContext<'_> {
223        PropagatorConstructorContext {
224            propagator_id: self.propagator_id,
225            next_local_id: self.next_local_id.reborrow(),
226            did_register: self.did_register.reborrow(),
227            state: self.state,
228        }
229    }
230
231    /// Add an inference checker for inferences produced by the propagator.
232    ///
233    /// If the `check-propagations` feature is not enabled, adding an [`InferenceChecker`] will not
234    /// do anything.
235    pub fn add_inference_checker(
236        &mut self,
237        inference_code: InferenceCode,
238        checker: Box<dyn InferenceChecker<Predicate>>,
239    ) {
240        self.state.add_inference_checker(inference_code, checker);
241    }
242
243    /// Set the next local id to be at least one more than the largest encountered local id.
244    fn update_next_local_id(&mut self, local_id: LocalId) {
245        let next_local_id = (*self.next_local_id.deref()).max(LocalId::from(local_id.unpack() + 1));
246
247        *self.next_local_id.deref_mut() = next_local_id;
248    }
249}
250
251impl Drop for PropagatorConstructorContext<'_> {
252    fn drop(&mut self) {
253        if std::thread::panicking() {
254            // If we are already unwinding due to a panic, we do not want to trigger another one.
255            return;
256        }
257
258        let did_register = match self.did_register {
259            // If we are in a reborrowed context, we do not want to enforce registration.
260            RefOrOwned::Ref(_) => return,
261
262            RefOrOwned::Owned(did_register) => did_register,
263        };
264
265        if !did_register {
266            panic!(
267                "Propagator did not register to be enqueued. If this is intentional, call PropagatorConstructorContext::will_not_register_any_events()."
268            );
269        }
270    }
271}
272
// Implements `HasAssignments` for the constructor context by exposing the assignments and
// trailed values stored in the context's state.
mod private {
    use super::*;
    use crate::propagation::HasAssignments;

    impl HasAssignments for PropagatorConstructorContext<'_> {
        /// Returns the assignments held by the underlying state.
        fn assignments(&self) -> &Assignments {
            &self.state.assignments
        }

        /// Returns a shared reference to the trailed values held by the underlying state.
        fn trailed_values(&self) -> &TrailedValues {
            &self.state.trailed_values
        }

        /// Returns a mutable reference to the trailed values held by the underlying state.
        fn trailed_values_mut(&mut self) -> &mut TrailedValues {
            &mut self.state.trailed_values
        }
    }
}
291
#[cfg(test)]
mod tests {

    use super::*;
    use crate::variables::DomainId;

    /// Dropping an owning context without any registration must panic with the registration
    /// message. The expected message is pinned so that an unrelated panic during setup cannot
    /// make this test pass.
    #[test]
    #[should_panic(expected = "Propagator did not register to be enqueued")]
    fn panic_when_no_registration_happened() {
        let mut state = State::default();
        state.notification_engine.grow();

        let _c1 = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
    }

    /// Explicitly opting out of registration suppresses the panic on drop.
    #[test]
    fn do_not_panic_if_told_no_registration_will_happen() {
        let mut state = State::default();
        state.notification_engine.grow();

        let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
        ctx.will_not_register_any_events();
    }

    /// Dropping a reborrowed context that did not register must not panic; only the owning
    /// context enforces registration.
    #[test]
    fn do_not_panic_if_no_registration_happens_in_reborrowed() {
        let mut state = State::default();
        state.notification_engine.grow();

        let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
        let ctx2 = ctx.reborrow();
        drop(ctx2);

        ctx.will_not_register_any_events();
    }

    /// A registration made through a reborrowed context updates the next local id seen by
    /// the original context.
    #[test]
    fn reborrowing_remembers_next_local_id() {
        let mut state = State::default();
        state.notification_engine.grow();

        let mut c1 = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
        c1.will_not_register_any_events();

        let mut c2 = c1.reborrow();
        c2.register(DomainId::new(0), DomainEvents::ANY_INT, LocalId::from(1));
        drop(c2);

        assert_eq!(LocalId::from(2), c1.get_next_local_id());
    }
}
342}