Skip to main content

pumpkin_checking/
variable_state.rs

1use std::collections::BTreeSet;
2
3use fnv::FnvHashMap;
4
5use crate::AtomicConstraint;
6use crate::Comparison;
7#[cfg(doc)]
8use crate::InferenceChecker;
9use crate::IntExt;
10
11/// The domains of all variables in the problem.
12///
13/// Domains are initially unbounded. This is why bounds are represented as [`IntExt`].
14///
15/// Domains can be reduced through [`VariableState::apply`]. By default, the domain of every
16/// variable is infinite.
17#[derive(Clone, Debug)]
18pub struct VariableState<Atomic: AtomicConstraint> {
19    domains: FnvHashMap<Atomic::Identifier, Domain>,
20}
21
22impl<Atomic: AtomicConstraint> Default for VariableState<Atomic> {
23    fn default() -> Self {
24        Self {
25            domains: Default::default(),
26        }
27    }
28}
29
30impl<Atomic> VariableState<Atomic>
31where
32    Atomic: AtomicConstraint,
33{
34    /// Create a variable state that applies all the premises and, if present, the negation of the
35    /// consequent.
36    ///
37    /// If `premises /\ !consequent` contain mutually exclusive atomic constraints (e.g., `[x >=
38    /// 5]` and `[x <= 2]`) then `None` is returned.
39    ///
40    /// An [`InferenceChecker`] will receive a [`VariableState`] that conforms to this description.
41    pub fn prepare_for_conflict_check(
42        premises: impl IntoIterator<Item = Atomic>,
43        consequent: Option<Atomic>,
44    ) -> Result<Self, Atomic::Identifier> {
45        let mut variable_state = VariableState::default();
46
47        let negated_consequent = consequent.as_ref().map(AtomicConstraint::negate);
48
49        // Apply all the premises and the negation of the consequent to the state.
50        if let Some(premise) = premises
51            .into_iter()
52            .chain(negated_consequent)
53            .find(|premise| !variable_state.apply(premise))
54        {
55            return Err(premise.identifier());
56        }
57
58        Ok(variable_state)
59    }
60
61    /// The domains for which at least one atomic is applied.
62    pub fn domains<'this>(&'this self) -> impl Iterator<Item = &'this Atomic::Identifier> + 'this
63    where
64        Atomic::Identifier: 'this,
65    {
66        self.domains.keys()
67    }
68
69    /// Get the lower bound of a variable.
70    pub fn lower_bound(&self, identifier: &Atomic::Identifier) -> IntExt {
71        self.domains
72            .get(identifier)
73            .map(|domain| domain.lower_bound)
74            .unwrap_or(IntExt::NegativeInf)
75    }
76
77    /// Get the upper bound of a variable.
78    pub fn upper_bound(&self, identifier: &Atomic::Identifier) -> IntExt {
79        self.domains
80            .get(identifier)
81            .map(|domain| domain.upper_bound)
82            .unwrap_or(IntExt::PositiveInf)
83    }
84
85    /// Tests whether the given value is in the domain of the variable.
86    pub fn contains(&self, identifier: &Atomic::Identifier, value: i32) -> bool {
87        self.domains
88            .get(identifier)
89            .map(|domain| {
90                value >= domain.lower_bound
91                    && value <= domain.upper_bound
92                    && !domain.holes.contains(&value)
93            })
94            .unwrap_or(true)
95    }
96
97    /// Get the holes within the lower and upper bound of the variable expression.
98    pub fn holes<'a>(&'a self, identifier: &Atomic::Identifier) -> impl Iterator<Item = i32> + 'a
99    where
100        Atomic::Identifier: 'a,
101    {
102        self.domains
103            .get(identifier)
104            .map(|domain| domain.holes.iter().copied())
105            .into_iter()
106            .flatten()
107    }
108
109    /// Get the fixed value of this variable, if it is fixed.
110    pub fn fixed_value(&self, identifier: &Atomic::Identifier) -> Option<i32> {
111        let domain = self.domains.get(identifier)?;
112
113        if domain.lower_bound == domain.upper_bound {
114            let IntExt::Int(value) = domain.lower_bound else {
115                panic!(
116                    "lower can only equal upper if they are integers, otherwise the sign of infinity makes them different"
117                );
118            };
119
120            Some(value)
121        } else {
122            None
123        }
124    }
125
126    /// Obtain an iterator over the domain of the variable.
127    ///
128    /// If the domain is unbounded, then `None` is returned.
129    pub fn iter_domain<'a>(&'a self, identifier: &Atomic::Identifier) -> Option<DomainIterator<'a>>
130    where
131        Atomic::Identifier: 'a,
132    {
133        let domain = self.domains.get(identifier)?;
134
135        let IntExt::Int(lower_bound) = domain.lower_bound else {
136            // If there is no lower bound, then the domain is unbounded.
137            return None;
138        };
139
140        // Ensure there is also an upper bound.
141        if !matches!(domain.upper_bound, IntExt::Int(_)) {
142            return None;
143        }
144
145        Some(DomainIterator {
146            domain,
147            next_value: lower_bound,
148        })
149    }
150
151    /// Apply the given `Atomic` to the state.
152    ///
153    /// Returns true if the state remains consistent, or false if the atomic cannot be true in
154    /// conjunction with previously applied atomics.
155    pub fn apply(&mut self, atomic: &Atomic) -> bool {
156        let identifier = atomic.identifier();
157        let domain = self
158            .domains
159            .entry(identifier)
160            .or_insert(Domain::all_integers());
161
162        match atomic.comparison() {
163            Comparison::GreaterEqual => {
164                domain.tighten_lower_bound(atomic.value());
165            }
166
167            Comparison::LessEqual => {
168                domain.tighten_upper_bound(atomic.value());
169            }
170
171            Comparison::Equal => {
172                domain.tighten_lower_bound(atomic.value());
173                domain.tighten_upper_bound(atomic.value());
174            }
175
176            Comparison::NotEqual => {
177                if domain.lower_bound == atomic.value() {
178                    domain.tighten_lower_bound(atomic.value() + 1);
179                }
180
181                if domain.upper_bound == atomic.value() {
182                    domain.tighten_upper_bound(atomic.value() - 1);
183                }
184
185                if domain.lower_bound < atomic.value() && domain.upper_bound > atomic.value() {
186                    let _ = domain.holes.insert(atomic.value());
187                }
188            }
189        }
190
191        domain.is_consistent()
192    }
193
194    /// Is the given atomic true in the current state.
195    pub fn is_true(&self, atomic: &Atomic) -> bool {
196        let Some(domain) = self.domains.get(&atomic.identifier()) else {
197            return false;
198        };
199
200        match atomic.comparison() {
201            Comparison::GreaterEqual => domain.lower_bound >= atomic.value(),
202
203            Comparison::LessEqual => domain.upper_bound <= atomic.value(),
204
205            Comparison::Equal => {
206                domain.lower_bound >= atomic.value() && domain.upper_bound <= atomic.value()
207            }
208
209            Comparison::NotEqual => {
210                if domain.lower_bound >= atomic.value() {
211                    return true;
212                }
213
214                if domain.upper_bound <= atomic.value() {
215                    return true;
216                }
217
218                if domain.holes.contains(&atomic.value()) {
219                    return true;
220                }
221
222                false
223            }
224        }
225    }
226}
227
228/// A domain inside the variable state.
229#[derive(Clone, Debug)]
230pub struct Domain {
231    lower_bound: IntExt,
232    upper_bound: IntExt,
233    holes: BTreeSet<i32>,
234}
235
236impl Domain {
237    /// Create a domain that contains all integers.
238    pub fn all_integers() -> Domain {
239        Domain {
240            lower_bound: IntExt::NegativeInf,
241            upper_bound: IntExt::PositiveInf,
242            holes: BTreeSet::default(),
243        }
244    }
245
246    /// Create an empty/inconsistent domain.
247    pub fn empty() -> Domain {
248        Domain {
249            lower_bound: IntExt::PositiveInf,
250            upper_bound: IntExt::NegativeInf,
251            holes: BTreeSet::default(),
252        }
253    }
254
255    /// Construct a new domain.
256    pub fn new(lower_bound: IntExt, upper_bound: IntExt, holes: BTreeSet<i32>) -> Self {
257        let mut domain = Domain::all_integers();
258        domain.holes = holes;
259
260        if let IntExt::Int(bound) = lower_bound {
261            domain.tighten_lower_bound(bound);
262        }
263
264        if let IntExt::Int(bound) = upper_bound {
265            domain.tighten_upper_bound(bound);
266        }
267
268        domain
269    }
270
271    /// Get the holes in the domain.
272    pub fn holes(&self) -> &BTreeSet<i32> {
273        &self.holes
274    }
275
276    /// Get the lower bound of the domain.
277    pub fn lower_bound(&self) -> IntExt {
278        self.lower_bound
279    }
280
281    /// Get the upper bound of the domain.
282    pub fn upper_bound(&self) -> IntExt {
283        self.upper_bound
284    }
285
286    /// Tighten the lower bound and remove any holes that are no longer strictly larger than the
287    /// lower bound.
288    fn tighten_lower_bound(&mut self, bound: i32) {
289        if self.lower_bound >= bound && !self.holes.contains(&bound) {
290            return;
291        }
292
293        self.lower_bound = IntExt::Int(bound);
294        self.holes = self.holes.split_off(&bound);
295
296        // Take care of the condition where the new bound is already a hole in the domain.
297        if self.holes.contains(&bound) {
298            self.tighten_lower_bound(bound + 1);
299        }
300    }
301
302    /// Tighten the upper bound and remove any holes that are no longer strictly smaller than the
303    /// upper bound.
304    fn tighten_upper_bound(&mut self, bound: i32) {
305        if self.upper_bound <= bound && !self.holes.contains(&bound) {
306            return;
307        }
308
309        self.upper_bound = IntExt::Int(bound);
310
311        // Note the '+ 1' to keep the elements <= the upper bound instead of <
312        // the upper bound.
313        let _ = self.holes.split_off(&(bound + 1));
314
315        // Take care of the condition where the new bound is already a hole in the domain.
316        if self.holes.contains(&bound) {
317            self.tighten_upper_bound(bound - 1);
318        }
319    }
320
321    /// Returns true if the domain contains at least one value.
322    pub fn is_consistent(&self) -> bool {
323        // No need to check holes, as the invariant of `Domain` specifies the bounds are as tight
324        // as possible, taking holes into account.
325
326        self.lower_bound <= self.upper_bound
327    }
328}
329
330/// An iterator over the values in the domain of a variable.
331#[derive(Debug)]
332pub struct DomainIterator<'a> {
333    domain: &'a Domain,
334    next_value: i32,
335}
336
337impl Iterator for DomainIterator<'_> {
338    type Item = i32;
339
340    fn next(&mut self) -> Option<Self::Item> {
341        let DomainIterator { domain, next_value } = self;
342
343        let IntExt::Int(upper_bound) = domain.upper_bound else {
344            panic!("Only finite domains can be iterated.")
345        };
346
347        loop {
348            // We have completed iterating the domain.
349            if *next_value > upper_bound {
350                return None;
351            }
352
353            let value = *next_value;
354            *next_value += 1;
355
356            // The next value is not part of the domain.
357            if domain.holes.contains(&value) {
358                continue;
359            }
360
361            // Here the value is part of the domain, so we yield it.
362            return Some(value);
363        }
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAtomic;

    /// Applies `[x1 <comparison> value]` to `state`; whether it stays consistent is irrelevant for
    /// these tests.
    fn restrict_x1(state: &mut VariableState<TestAtomic>, comparison: Comparison, value: i32) {
        let _ = state.apply(&TestAtomic {
            name: "x1",
            comparison,
            value,
        });
    }

    #[test]
    fn domain_iterator_unbounded() {
        let state = VariableState::<TestAtomic>::default();

        assert!(state.iter_domain(&"x1").is_none());
    }

    #[test]
    fn domain_iterator_unbounded_lower_bound() {
        let mut state = VariableState::default();
        restrict_x1(&mut state, Comparison::LessEqual, 5);

        assert!(state.iter_domain(&"x1").is_none());
    }

    #[test]
    fn domain_iterator_unbounded_upper_bound() {
        let mut state = VariableState::default();
        restrict_x1(&mut state, Comparison::GreaterEqual, 5);

        assert!(state.iter_domain(&"x1").is_none());
    }

    #[test]
    fn domain_iterator_bounded_no_holes() {
        let mut state = VariableState::default();
        restrict_x1(&mut state, Comparison::GreaterEqual, 5);
        restrict_x1(&mut state, Comparison::LessEqual, 10);

        let values: Vec<i32> = state
            .iter_domain(&"x1")
            .expect("the domain is bounded")
            .collect();

        assert_eq!(values, vec![5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn domain_iterator_bounded_with_holes() {
        let mut state = VariableState::default();
        restrict_x1(&mut state, Comparison::GreaterEqual, 5);
        restrict_x1(&mut state, Comparison::NotEqual, 7);
        restrict_x1(&mut state, Comparison::LessEqual, 10);

        let values: Vec<i32> = state
            .iter_domain(&"x1")
            .expect("the domain is bounded")
            .collect();

        assert_eq!(values, vec![5, 6, 8, 9, 10]);
    }
}