1use std::fmt::Debug;
4
5use pumpkin_checking::InferenceChecker;
6
7use super::PropagatorQueue;
8use crate::containers::KeyGenerator;
9use crate::engine::EmptyDomain;
10use crate::engine::State;
11use crate::engine::predicates::predicate::Predicate;
12use crate::engine::variables::DomainId;
13use crate::engine::variables::IntegerVariable;
14use crate::engine::variables::Literal;
15use crate::options::LearningOptions;
16use crate::predicate;
17use crate::predicates::PropositionalConjunction;
18use crate::proof::ConstraintTag;
19use crate::proof::InferenceCode;
20use crate::propagation::EnqueueDecision;
21use crate::propagation::ExplanationContext;
22use crate::propagation::NotificationContext;
23use crate::propagation::PropagationContext;
24use crate::propagation::PropagatorConstructor;
25use crate::propagation::PropagatorId;
26use crate::propagators::nogoods::NogoodPropagator;
27use crate::propagators::nogoods::NogoodPropagatorConstructor;
28use crate::state::Conflict;
29use crate::state::PropagatorHandle;
30
/// A test harness wrapping a [`State`], with helpers for creating variables, posting domain
/// changes, notifying propagators and running individual propagators.
///
/// On construction (see the `Default` implementation) a [`NogoodPropagator`] is registered and
/// its handle is stored in `nogood_handle`.
#[derive(Debug)]
pub struct TestSolver {
    /// The wrapped solver state holding variables, assignments and propagators.
    pub state: State,
    /// Source of fresh constraint tags, handed out by `new_constraint_tag`.
    constraint_tags: KeyGenerator<ConstraintTag>,
    /// Handle to the nogood propagator that is added when the solver is created.
    pub nogood_handle: PropagatorHandle<NogoodPropagator>,
}
38
39impl Default for TestSolver {
40 fn default() -> Self {
41 let mut state = State::default();
42 let handle = state.add_propagator(NogoodPropagatorConstructor::new(
43 0,
44 LearningOptions::default(),
45 ));
46 let mut solver = Self {
47 state,
48 constraint_tags: Default::default(),
49 nogood_handle: handle,
50 };
51 solver.state.notification_engine.grow();
53 solver
54 }
55}
56
57#[deprecated = "Will be replaced by the state API"]
58impl TestSolver {
59 pub fn accept_inferences_by(&mut self, inference_code: InferenceCode) {
60 #[derive(Debug, Clone, Copy)]
61 struct Checker;
62
63 impl InferenceChecker<Predicate> for Checker {
64 fn check(
65 &self,
66 _: pumpkin_checking::VariableState<Predicate>,
67 _: &[Predicate],
68 _: Option<&Predicate>,
69 ) -> bool {
70 true
71 }
72 }
73
74 self.state
75 .add_inference_checker(inference_code, Box::new(Checker));
76 }
77
78 pub fn new_variable(&mut self, lb: i32, ub: i32) -> DomainId {
79 self.state.new_interval_variable(lb, ub, None)
80 }
81
82 pub fn new_sparse_variable(&mut self, values: Vec<i32>) -> DomainId {
83 self.state.new_sparse_variable(values, None)
84 }
85
86 pub fn new_literal(&mut self) -> Literal {
87 let domain_id = self.new_variable(0, 1);
88 Literal::new(domain_id)
89 }
90
91 pub fn new_propagator<Constructor>(
92 &mut self,
93 constructor: Constructor,
94 ) -> Result<PropagatorId, Conflict>
95 where
96 Constructor: PropagatorConstructor,
97 Constructor::PropagatorImpl: 'static,
98 {
99 let handle = self.state.add_propagator(constructor);
100 self.state
101 .propagate_to_fixed_point()
102 .map(|_| handle.propagator_id())
103 }
104
105 pub fn contains<Var: IntegerVariable>(&self, var: Var, value: i32) -> bool {
106 var.contains(&self.state.assignments, value)
107 }
108
109 pub fn lower_bound(&self, var: DomainId) -> i32 {
110 self.state.assignments.get_lower_bound(var)
111 }
112
113 pub fn remove_and_notify(
114 &mut self,
115 propagator: PropagatorId,
116 var: DomainId,
117 value: i32,
118 ) -> EnqueueDecision {
119 let result = self.state.post(predicate!(var != value));
120 assert!(
121 result.is_ok(),
122 "The provided value to `increase_lower_bound` caused an empty domain, generally the propagator should not be notified of this change!"
123 );
124 let mut propagator_queue = PropagatorQueue::new(4);
125 #[allow(deprecated, reason = "Will be refactored in the future")]
126 self.state
127 .notification_engine
128 .notify_propagators_about_domain_events_test(
129 &mut self.state.assignments,
130 &mut self.state.trailed_values,
131 &mut self.state.propagators,
132 &mut propagator_queue,
133 );
134 if propagator_queue.is_propagator_enqueued(propagator) {
135 EnqueueDecision::Enqueue
136 } else {
137 EnqueueDecision::Skip
138 }
139 }
140
141 pub fn increase_lower_bound_and_notify(
142 &mut self,
143 propagator: PropagatorId,
144 _local_id: u32,
145 var: DomainId,
146 value: i32,
147 ) -> EnqueueDecision {
148 let result = self.state.post(predicate!(var >= value));
149 assert!(
150 result.is_ok(),
151 "The provided value to `increase_lower_bound` caused an empty domain, generally the propagator should not be notified of this change!"
152 );
153 let mut propagator_queue = PropagatorQueue::new(4);
154 #[allow(deprecated, reason = "Will be refactored in the future")]
155 self.state
156 .notification_engine
157 .notify_propagators_about_domain_events_test(
158 &mut self.state.assignments,
159 &mut self.state.trailed_values,
160 &mut self.state.propagators,
161 &mut propagator_queue,
162 );
163 if propagator_queue.is_propagator_enqueued(propagator) {
164 EnqueueDecision::Enqueue
165 } else {
166 EnqueueDecision::Skip
167 }
168 }
169
170 pub fn decrease_upper_bound_and_notify(
171 &mut self,
172 propagator: PropagatorId,
173 _local_id: u32,
174 var: DomainId,
175 value: i32,
176 ) -> EnqueueDecision {
177 let result = self.state.post(predicate!(var <= value));
178 assert!(
179 result.is_ok(),
180 "The provided value to `increase_lower_bound` caused an empty domain, generally the propagator should not be notified of this change!"
181 );
182 let mut propagator_queue = PropagatorQueue::new(4);
183 #[allow(deprecated, reason = "Will be refactored in the future")]
184 self.state
185 .notification_engine
186 .notify_propagators_about_domain_events_test(
187 &mut self.state.assignments,
188 &mut self.state.trailed_values,
189 &mut self.state.propagators,
190 &mut propagator_queue,
191 );
192 if propagator_queue.is_propagator_enqueued(propagator) {
193 EnqueueDecision::Enqueue
194 } else {
195 EnqueueDecision::Skip
196 }
197 }
198
199 pub fn is_literal_false(&self, literal: Literal) -> bool {
200 self.state
201 .assignments
202 .evaluate_predicate(literal.get_true_predicate())
203 .is_some_and(|truth_value| !truth_value)
204 }
205
206 pub fn upper_bound(&self, var: DomainId) -> i32 {
207 self.state.assignments.get_upper_bound(var)
208 }
209
210 pub fn remove(&mut self, var: DomainId, value: i32) -> Result<(), EmptyDomain> {
211 let _ = self.state.post(predicate!(var != value))?;
212
213 Ok(())
214 }
215
216 pub fn set_literal(&mut self, literal: Literal, truth_value: bool) -> Result<(), EmptyDomain> {
217 let _ = match truth_value {
218 true => self.state.assignments.post_predicate(
219 literal.get_true_predicate(),
220 None,
221 &mut self.state.notification_engine,
222 )?,
223 false => self.state.assignments.post_predicate(
224 (!literal).get_true_predicate(),
225 None,
226 &mut self.state.notification_engine,
227 )?,
228 };
229
230 Ok(())
231 }
232
233 pub fn propagate(&mut self, propagator: PropagatorId) -> Result<(), Conflict> {
234 let context = PropagationContext::new(
235 &mut self.state.trailed_values,
236 &mut self.state.assignments,
237 &mut self.state.reason_store,
238 &mut self.state.notification_engine,
239 propagator,
240 );
241 self.state.propagators[propagator].propagate(context)
242 }
243
244 pub fn propagate_until_fixed_point(
245 &mut self,
246 propagator: PropagatorId,
247 ) -> Result<(), Conflict> {
248 let mut num_trail_entries = self.state.assignments.num_trail_entries();
249 self.notify_propagator(propagator);
250 loop {
251 {
252 let context = PropagationContext::new(
254 &mut self.state.trailed_values,
255 &mut self.state.assignments,
256 &mut self.state.reason_store,
257 &mut self.state.notification_engine,
258 propagator,
259 );
260 self.state.propagators[propagator].propagate(context)?;
261 self.notify_propagator(propagator);
262 }
263 if self.state.assignments.num_trail_entries() == num_trail_entries {
264 break;
265 }
266 num_trail_entries = self.state.assignments.num_trail_entries();
267 }
268 Ok(())
269 }
270
271 pub fn notify_propagator(&mut self, _propagator: PropagatorId) {
272 #[allow(deprecated, reason = "Will be refactored in the future")]
273 self.state
274 .notification_engine
275 .notify_propagators_about_domain_events_test(
276 &mut self.state.assignments,
277 &mut self.state.trailed_values,
278 &mut self.state.propagators,
279 &mut PropagatorQueue::new(4),
280 );
281 }
282
283 pub fn get_reason_int(&mut self, predicate: Predicate) -> PropositionalConjunction {
284 #[allow(deprecated, reason = "Will be refactored in the future")]
285 let reason_ref = self
286 .state
287 .assignments
288 .get_reason_for_predicate_brute_force(predicate);
289 let mut predicates = vec![];
290 let _ = self.state.reason_store.get_or_compute(
291 reason_ref,
292 ExplanationContext::without_working_nogood(
293 &self.state.assignments,
294 self.state
295 .assignments
296 .get_trail_position(&predicate)
297 .unwrap(),
298 &mut self.state.notification_engine,
299 ),
300 &mut self.state.propagators,
301 &mut predicates,
302 );
303
304 PropositionalConjunction::from(predicates)
305 }
306
307 pub fn get_reason_bool(
308 &mut self,
309 literal: Literal,
310 truth_value: bool,
311 ) -> PropositionalConjunction {
312 let predicate = match truth_value {
313 true => literal.get_true_predicate(),
314 false => (!literal).get_true_predicate(),
315 };
316 self.get_reason_int(predicate)
317 }
318
319 pub fn assert_bounds(&self, var: DomainId, lb: i32, ub: i32) {
320 let actual_lb = self.lower_bound(var);
321 let actual_ub = self.upper_bound(var);
322
323 assert_eq!(
324 (lb, ub),
325 (actual_lb, actual_ub),
326 "The expected bounds [{lb}..{ub}] did not match the actual bounds [{actual_lb}..{actual_ub}]"
327 );
328 }
329
330 pub fn new_constraint_tag(&mut self) -> ConstraintTag {
331 self.constraint_tags.next_key()
332 }
333
334 pub fn new_checkpoint(&mut self) {
335 self.state.new_checkpoint();
336 }
337
338 pub fn synchronise(&mut self, level: usize) {
339 let _ = self
340 .state
341 .assignments
342 .synchronise(level, &mut self.state.notification_engine);
343 self.state.notification_engine.synchronise(
344 level,
345 &self.state.assignments,
346 &mut self.state.trailed_values,
347 );
348 self.state.trailed_values.synchronise(level);
349
350 for propagator in self.state.propagators.iter_propagators_mut() {
351 let mut context =
352 NotificationContext::new(&mut self.state.trailed_values, &self.state.assignments);
353
354 propagator.synchronise(context.reborrow());
355 }
356 }
357}