pumpkin_core/propagation/
constructor.rs1use std::ops::Deref;
2use std::ops::DerefMut;
3
4use pumpkin_checking::InferenceChecker;
5
6use super::Domains;
7use super::LocalId;
8use super::Propagator;
9use super::PropagatorId;
10use super::PropagatorVarId;
11#[cfg(doc)]
12use crate::Solver;
13use crate::basic_types::PredicateId;
14use crate::basic_types::RefOrOwned;
15use crate::engine::Assignments;
16use crate::engine::State;
17use crate::engine::TrailedValues;
18use crate::engine::notifications::Watchers;
19#[cfg(doc)]
20use crate::engine::variables::AffineView;
21#[cfg(doc)]
22use crate::engine::variables::DomainId;
23use crate::predicates::Predicate;
24use crate::proof::InferenceCode;
25#[cfg(doc)]
26use crate::propagation::DomainEvent;
27use crate::propagation::DomainEvents;
28use crate::propagators::reified_propagator::ReifiedChecker;
29use crate::variables::IntegerVariable;
30use crate::variables::Literal;
31
32pub trait PropagatorConstructor {
39 type PropagatorImpl: Propagator + Clone;
41
42 fn add_inference_checkers(&self, _checkers: InferenceCheckers<'_>) {}
49
50 fn create(self, context: PropagatorConstructorContext) -> Self::PropagatorImpl;
52}
53
54#[derive(Debug)]
56pub struct InferenceCheckers<'state> {
57 state: &'state mut State,
58 reification_literal: Option<Literal>,
59}
60
61impl<'state> InferenceCheckers<'state> {
62 #[cfg(feature = "check-propagations")]
63 pub(crate) fn new(state: &'state mut State) -> Self {
64 InferenceCheckers {
65 state,
66 reification_literal: None,
67 }
68 }
69}
70
71impl InferenceCheckers<'_> {
72 pub fn add_inference_checker(
74 &mut self,
75 inference_code: InferenceCode,
76 checker: Box<dyn InferenceChecker<Predicate>>,
77 ) {
78 if let Some(reification_literal) = self.reification_literal {
79 let reification_checker = ReifiedChecker {
80 inner: checker.into(),
81 reification_literal,
82 };
83 self.state
84 .add_inference_checker(inference_code, Box::new(reification_checker));
85 } else {
86 self.state.add_inference_checker(inference_code, checker);
87 }
88 }
89
90 pub fn with_reification_literal(&mut self, literal: Literal) {
91 self.reification_literal = Some(literal)
92 }
93}
94
95#[derive(Debug)]
101pub struct PropagatorConstructorContext<'a> {
102 state: &'a mut State,
103 pub(crate) propagator_id: PropagatorId,
104
105 next_local_id: RefOrOwned<'a, LocalId>,
109
110 did_register: RefOrOwned<'a, bool>,
113}
114
115impl PropagatorConstructorContext<'_> {
116 pub(crate) fn new<'a>(
117 propagator_id: PropagatorId,
118 state: &'a mut State,
119 ) -> PropagatorConstructorContext<'a> {
120 PropagatorConstructorContext {
121 next_local_id: RefOrOwned::Owned(LocalId::from(0)),
122 propagator_id,
123 state,
124 did_register: RefOrOwned::Owned(false),
125 }
126 }
127
128 pub fn will_not_register_any_events(&mut self) {
134 *self.did_register = true;
135 }
136
137 pub fn domains(&mut self) -> Domains<'_> {
139 Domains::new(&self.state.assignments, &mut self.state.trailed_values)
140 }
141
142 pub fn register(
153 &mut self,
154 var: impl IntegerVariable,
155 domain_events: DomainEvents,
156 local_id: LocalId,
157 ) {
158 self.will_not_register_any_events();
159
160 let propagator_var = PropagatorVarId {
161 propagator: self.propagator_id,
162 variable: local_id,
163 };
164
165 self.update_next_local_id(local_id);
166
167 let mut watchers = Watchers::new(propagator_var, &mut self.state.notification_engine);
168 var.watch_all(&mut watchers, domain_events.events());
169 }
170
171 pub fn register_predicate(&mut self, predicate: Predicate) -> PredicateId {
174 self.will_not_register_any_events();
175
176 self.state.notification_engine.watch_predicate(
177 predicate,
178 self.propagator_id,
179 &mut self.state.trailed_values,
180 &self.state.assignments,
181 )
182 }
183
184 pub fn register_backtrack<Var: IntegerVariable>(
198 &mut self,
199 var: Var,
200 domain_events: DomainEvents,
201 local_id: LocalId,
202 ) {
203 let propagator_var = PropagatorVarId {
204 propagator: self.propagator_id,
205 variable: local_id,
206 };
207
208 self.update_next_local_id(local_id);
209
210 let mut watchers = Watchers::new(propagator_var, &mut self.state.notification_engine);
211 var.watch_all_backtrack(&mut watchers, domain_events.events());
212 }
213
214 pub(crate) fn get_next_local_id(&self) -> LocalId {
216 *self.next_local_id.deref()
217 }
218
219 pub fn reborrow(&mut self) -> PropagatorConstructorContext<'_> {
223 PropagatorConstructorContext {
224 propagator_id: self.propagator_id,
225 next_local_id: self.next_local_id.reborrow(),
226 did_register: self.did_register.reborrow(),
227 state: self.state,
228 }
229 }
230
231 pub fn add_inference_checker(
236 &mut self,
237 inference_code: InferenceCode,
238 checker: Box<dyn InferenceChecker<Predicate>>,
239 ) {
240 self.state.add_inference_checker(inference_code, checker);
241 }
242
243 fn update_next_local_id(&mut self, local_id: LocalId) {
245 let next_local_id = (*self.next_local_id.deref()).max(LocalId::from(local_id.unpack() + 1));
246
247 *self.next_local_id.deref_mut() = next_local_id;
248 }
249}
250
251impl Drop for PropagatorConstructorContext<'_> {
252 fn drop(&mut self) {
253 if std::thread::panicking() {
254 return;
256 }
257
258 let did_register = match self.did_register {
259 RefOrOwned::Ref(_) => return,
261
262 RefOrOwned::Owned(did_register) => did_register,
263 };
264
265 if !did_register {
266 panic!(
267 "Propagator did not register to be enqueued. If this is intentional, call PropagatorConstructorContext::will_not_register_any_events()."
268 );
269 }
270 }
271}
272
273mod private {
274 use super::*;
275 use crate::propagation::HasAssignments;
276
277 impl HasAssignments for PropagatorConstructorContext<'_> {
278 fn assignments(&self) -> &Assignments {
279 &self.state.assignments
280 }
281
282 fn trailed_values(&self) -> &TrailedValues {
283 &self.state.trailed_values
284 }
285
286 fn trailed_values_mut(&mut self) -> &mut TrailedValues {
287 &mut self.state.trailed_values
288 }
289 }
290}
291
#[cfg(test)]
mod tests {

    use super::*;
    use crate::variables::DomainId;

    /// Builds a default state and calls `grow()` once on its notification engine, which is
    /// the setup every test in this module uses.
    fn test_state() -> State {
        let mut state = State::default();
        state.notification_engine.grow();
        state
    }

    // Pin the expected message so the test cannot pass because of an unrelated panic
    // (e.g. during state setup).
    #[test]
    #[should_panic(expected = "Propagator did not register to be enqueued")]
    fn panic_when_no_registration_happened() {
        let mut state = test_state();

        let _c1 = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
    }

    #[test]
    fn do_not_panic_if_told_no_registration_will_happen() {
        let mut state = test_state();

        let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
        ctx.will_not_register_any_events();
    }

    #[test]
    fn do_not_panic_if_no_registration_happens_in_reborrowed() {
        let mut state = test_state();

        let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
        let ctx2 = ctx.reborrow();
        drop(ctx2);

        ctx.will_not_register_any_events();
    }

    #[test]
    fn registering_an_event_satisfies_the_registration_check() {
        let mut state = test_state();

        let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
        ctx.register(DomainId::new(0), DomainEvents::ANY_INT, LocalId::from(0));
    }

    #[test]
    fn registering_through_a_reborrow_satisfies_the_owner() {
        let mut state = test_state();

        let mut ctx = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
        let mut reborrowed = ctx.reborrow();
        reborrowed.register(DomainId::new(0), DomainEvents::ANY_INT, LocalId::from(0));
        drop(reborrowed);
    }

    #[test]
    fn reborrowing_remembers_next_local_id() {
        let mut state = test_state();

        let mut c1 = PropagatorConstructorContext::new(PropagatorId(0), &mut state);
        c1.will_not_register_any_events();

        let mut c2 = c1.reborrow();
        c2.register(DomainId::new(0), DomainEvents::ANY_INT, LocalId::from(1));
        drop(c2);

        assert_eq!(LocalId::from(2), c1.get_next_local_id());
    }
}
342}