pumpkin_checking/
variable_state.rs1use std::collections::BTreeSet;
2
3use fnv::FnvHashMap;
4
5use crate::AtomicConstraint;
6use crate::Comparison;
7#[cfg(doc)]
8use crate::InferenceChecker;
9use crate::IntExt;
10
11#[derive(Clone, Debug)]
18pub struct VariableState<Atomic: AtomicConstraint> {
19 domains: FnvHashMap<Atomic::Identifier, Domain>,
20}
21
22impl<Atomic: AtomicConstraint> Default for VariableState<Atomic> {
23 fn default() -> Self {
24 Self {
25 domains: Default::default(),
26 }
27 }
28}
29
30impl<Atomic> VariableState<Atomic>
31where
32 Atomic: AtomicConstraint,
33{
34 pub fn prepare_for_conflict_check(
42 premises: impl IntoIterator<Item = Atomic>,
43 consequent: Option<Atomic>,
44 ) -> Result<Self, Atomic::Identifier> {
45 let mut variable_state = VariableState::default();
46
47 let negated_consequent = consequent.as_ref().map(AtomicConstraint::negate);
48
49 if let Some(premise) = premises
51 .into_iter()
52 .chain(negated_consequent)
53 .find(|premise| !variable_state.apply(premise))
54 {
55 return Err(premise.identifier());
56 }
57
58 Ok(variable_state)
59 }
60
61 pub fn domains<'this>(&'this self) -> impl Iterator<Item = &'this Atomic::Identifier> + 'this
63 where
64 Atomic::Identifier: 'this,
65 {
66 self.domains.keys()
67 }
68
69 pub fn lower_bound(&self, identifier: &Atomic::Identifier) -> IntExt {
71 self.domains
72 .get(identifier)
73 .map(|domain| domain.lower_bound)
74 .unwrap_or(IntExt::NegativeInf)
75 }
76
77 pub fn upper_bound(&self, identifier: &Atomic::Identifier) -> IntExt {
79 self.domains
80 .get(identifier)
81 .map(|domain| domain.upper_bound)
82 .unwrap_or(IntExt::PositiveInf)
83 }
84
85 pub fn contains(&self, identifier: &Atomic::Identifier, value: i32) -> bool {
87 self.domains
88 .get(identifier)
89 .map(|domain| {
90 value >= domain.lower_bound
91 && value <= domain.upper_bound
92 && !domain.holes.contains(&value)
93 })
94 .unwrap_or(true)
95 }
96
97 pub fn holes<'a>(&'a self, identifier: &Atomic::Identifier) -> impl Iterator<Item = i32> + 'a
99 where
100 Atomic::Identifier: 'a,
101 {
102 self.domains
103 .get(identifier)
104 .map(|domain| domain.holes.iter().copied())
105 .into_iter()
106 .flatten()
107 }
108
109 pub fn fixed_value(&self, identifier: &Atomic::Identifier) -> Option<i32> {
111 let domain = self.domains.get(identifier)?;
112
113 if domain.lower_bound == domain.upper_bound {
114 let IntExt::Int(value) = domain.lower_bound else {
115 panic!(
116 "lower can only equal upper if they are integers, otherwise the sign of infinity makes them different"
117 );
118 };
119
120 Some(value)
121 } else {
122 None
123 }
124 }
125
126 pub fn iter_domain<'a>(&'a self, identifier: &Atomic::Identifier) -> Option<DomainIterator<'a>>
130 where
131 Atomic::Identifier: 'a,
132 {
133 let domain = self.domains.get(identifier)?;
134
135 let IntExt::Int(lower_bound) = domain.lower_bound else {
136 return None;
138 };
139
140 if !matches!(domain.upper_bound, IntExt::Int(_)) {
142 return None;
143 }
144
145 Some(DomainIterator {
146 domain,
147 next_value: lower_bound,
148 })
149 }
150
151 pub fn apply(&mut self, atomic: &Atomic) -> bool {
156 let identifier = atomic.identifier();
157 let domain = self
158 .domains
159 .entry(identifier)
160 .or_insert(Domain::all_integers());
161
162 match atomic.comparison() {
163 Comparison::GreaterEqual => {
164 domain.tighten_lower_bound(atomic.value());
165 }
166
167 Comparison::LessEqual => {
168 domain.tighten_upper_bound(atomic.value());
169 }
170
171 Comparison::Equal => {
172 domain.tighten_lower_bound(atomic.value());
173 domain.tighten_upper_bound(atomic.value());
174 }
175
176 Comparison::NotEqual => {
177 if domain.lower_bound == atomic.value() {
178 domain.tighten_lower_bound(atomic.value() + 1);
179 }
180
181 if domain.upper_bound == atomic.value() {
182 domain.tighten_upper_bound(atomic.value() - 1);
183 }
184
185 if domain.lower_bound < atomic.value() && domain.upper_bound > atomic.value() {
186 let _ = domain.holes.insert(atomic.value());
187 }
188 }
189 }
190
191 domain.is_consistent()
192 }
193
194 pub fn is_true(&self, atomic: &Atomic) -> bool {
196 let Some(domain) = self.domains.get(&atomic.identifier()) else {
197 return false;
198 };
199
200 match atomic.comparison() {
201 Comparison::GreaterEqual => domain.lower_bound >= atomic.value(),
202
203 Comparison::LessEqual => domain.upper_bound <= atomic.value(),
204
205 Comparison::Equal => {
206 domain.lower_bound >= atomic.value() && domain.upper_bound <= atomic.value()
207 }
208
209 Comparison::NotEqual => {
210 if domain.lower_bound >= atomic.value() {
211 return true;
212 }
213
214 if domain.upper_bound <= atomic.value() {
215 return true;
216 }
217
218 if domain.holes.contains(&atomic.value()) {
219 return true;
220 }
221
222 false
223 }
224 }
225 }
226}
227
228#[derive(Clone, Debug)]
230pub struct Domain {
231 lower_bound: IntExt,
232 upper_bound: IntExt,
233 holes: BTreeSet<i32>,
234}
235
236impl Domain {
237 pub fn all_integers() -> Domain {
239 Domain {
240 lower_bound: IntExt::NegativeInf,
241 upper_bound: IntExt::PositiveInf,
242 holes: BTreeSet::default(),
243 }
244 }
245
246 pub fn empty() -> Domain {
248 Domain {
249 lower_bound: IntExt::PositiveInf,
250 upper_bound: IntExt::NegativeInf,
251 holes: BTreeSet::default(),
252 }
253 }
254
255 pub fn new(lower_bound: IntExt, upper_bound: IntExt, holes: BTreeSet<i32>) -> Self {
257 let mut domain = Domain::all_integers();
258 domain.holes = holes;
259
260 if let IntExt::Int(bound) = lower_bound {
261 domain.tighten_lower_bound(bound);
262 }
263
264 if let IntExt::Int(bound) = upper_bound {
265 domain.tighten_upper_bound(bound);
266 }
267
268 domain
269 }
270
271 pub fn holes(&self) -> &BTreeSet<i32> {
273 &self.holes
274 }
275
276 pub fn lower_bound(&self) -> IntExt {
278 self.lower_bound
279 }
280
281 pub fn upper_bound(&self) -> IntExt {
283 self.upper_bound
284 }
285
286 fn tighten_lower_bound(&mut self, bound: i32) {
289 if self.lower_bound >= bound && !self.holes.contains(&bound) {
290 return;
291 }
292
293 self.lower_bound = IntExt::Int(bound);
294 self.holes = self.holes.split_off(&bound);
295
296 if self.holes.contains(&bound) {
298 self.tighten_lower_bound(bound + 1);
299 }
300 }
301
302 fn tighten_upper_bound(&mut self, bound: i32) {
305 if self.upper_bound <= bound && !self.holes.contains(&bound) {
306 return;
307 }
308
309 self.upper_bound = IntExt::Int(bound);
310
311 let _ = self.holes.split_off(&(bound + 1));
314
315 if self.holes.contains(&bound) {
317 self.tighten_upper_bound(bound - 1);
318 }
319 }
320
321 pub fn is_consistent(&self) -> bool {
323 self.lower_bound <= self.upper_bound
327 }
328}
329
330#[derive(Debug)]
332pub struct DomainIterator<'a> {
333 domain: &'a Domain,
334 next_value: i32,
335}
336
337impl Iterator for DomainIterator<'_> {
338 type Item = i32;
339
340 fn next(&mut self) -> Option<Self::Item> {
341 let DomainIterator { domain, next_value } = self;
342
343 let IntExt::Int(upper_bound) = domain.upper_bound else {
344 panic!("Only finite domains can be iterated.")
345 };
346
347 loop {
348 if *next_value > upper_bound {
350 return None;
351 }
352
353 let value = *next_value;
354 *next_value += 1;
355
356 if domain.holes.contains(&value) {
358 continue;
359 }
360
361 return Some(value);
363 }
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAtomic;

    /// Builds a fresh state and applies `atomics` to it in order, ignoring whether each
    /// application leaves the domain consistent.
    fn state_after(atomics: impl IntoIterator<Item = TestAtomic>) -> VariableState<TestAtomic> {
        let mut state = VariableState::default();
        for atomic in atomics {
            let _ = state.apply(&atomic);
        }
        state
    }

    #[test]
    fn domain_iterator_unbounded() {
        let state = VariableState::<TestAtomic>::default();

        assert!(state.iter_domain(&"x1").is_none());
    }

    #[test]
    fn domain_iterator_unbounded_lower_bound() {
        let state = state_after([TestAtomic {
            name: "x1",
            comparison: Comparison::LessEqual,
            value: 5,
        }]);

        assert!(state.iter_domain(&"x1").is_none());
    }

    #[test]
    fn domain_iterator_unbounded_upper_bound() {
        let state = state_after([TestAtomic {
            name: "x1",
            comparison: Comparison::GreaterEqual,
            value: 5,
        }]);

        assert!(state.iter_domain(&"x1").is_none());
    }

    #[test]
    fn domain_iterator_bounded_no_holes() {
        let state = state_after([
            TestAtomic {
                name: "x1",
                comparison: Comparison::GreaterEqual,
                value: 5,
            },
            TestAtomic {
                name: "x1",
                comparison: Comparison::LessEqual,
                value: 10,
            },
        ]);

        let values: Vec<i32> = state
            .iter_domain(&"x1")
            .expect("the domain is bounded")
            .collect();

        assert_eq!(values, vec![5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn domain_iterator_bounded_with_holes() {
        let state = state_after([
            TestAtomic {
                name: "x1",
                comparison: Comparison::GreaterEqual,
                value: 5,
            },
            TestAtomic {
                name: "x1",
                comparison: Comparison::NotEqual,
                value: 7,
            },
            TestAtomic {
                name: "x1",
                comparison: Comparison::LessEqual,
                value: 10,
            },
        ]);

        let values: Vec<i32> = state
            .iter_domain(&"x1")
            .expect("the domain is bounded")
            .collect();

        assert_eq!(values, vec![5, 6, 8, 9, 10]);
    }
}