Skip to main content

pumpkin_checking/
union.rs

1use std::collections::BTreeSet;
2
3use crate::AtomicConstraint;
4use crate::CheckerVariable;
5use crate::IntExt;
6use crate::VariableState;
7
8/// Calculates the union between multiple variable domains.
9///
10/// Can be used incrementally through [`Union::add`]. Using [`Union::reset`] the union is reset to
11/// an empty state.
12#[derive(Clone, Debug)]
13pub struct Union {
14    lower_bound: IntExt,
15    upper_bound: IntExt,
16    all_holes: BTreeSet<i32>,
17}
18
19impl Union {
20    /// Create an empty union.
21    pub fn empty() -> Self {
22        Union {
23            lower_bound: IntExt::PositiveInf,
24            upper_bound: IntExt::NegativeInf,
25            all_holes: BTreeSet::default(),
26        }
27    }
28
29    /// Get the lower bound of the union of all added domains.
30    ///
31    /// If the union is empty, this is [`IntExt::PositiveInf`].
32    pub fn lower_bound(&self) -> IntExt {
33        self.lower_bound
34    }
35
36    /// Get the upper bound of the union of all added domains.
37    ///
38    /// If the union is empty, this is [`IntExt::NegativeInf`].
39    pub fn upper_bound(&self) -> IntExt {
40        self.upper_bound
41    }
42
43    /// Get the holes in the union of all added domains.
44    pub fn holes(&self) -> impl Iterator<Item = i32> + '_ {
45        self.all_holes.iter().copied()
46    }
47
48    /// Test whether `value` is in the union.
49    pub fn contains(&self, value: i32) -> bool {
50        self.lower_bound <= value && value <= self.upper_bound && !self.all_holes.contains(&value)
51    }
52
53    /// If the domain is finite, returns the number of elements.
54    pub fn size(&self) -> Option<usize> {
55        let IntExt::Int(lower_bound) = self.lower_bound else {
56            return None;
57        };
58        let IntExt::Int(upper_bound) = self.upper_bound else {
59            return None;
60        };
61
62        let size = upper_bound.abs_diff(lower_bound) as usize + 1 - self.all_holes.len();
63        Some(size)
64    }
65
66    /// Rest the union to an empty state.
67    pub fn reset(&mut self) {
68        self.lower_bound = IntExt::PositiveInf;
69        self.upper_bound = IntExt::NegativeInf;
70        self.all_holes.clear();
71    }
72
73    /// Add a variable to the union.
74    pub fn add<Atomic: AtomicConstraint>(
75        &mut self,
76        state: &VariableState<Atomic>,
77        variable: &impl CheckerVariable<Atomic>,
78    ) {
79        // All holes that are already in the union but not in the new domain should be removed.
80        self.all_holes
81            .retain(|&value| !variable.induced_domain_contains(state, value));
82
83        let other_holes = variable
84            .induced_holes(state)
85            .filter(|&value| value < self.lower_bound || value > self.upper_bound);
86
87        self.all_holes.extend(other_holes);
88
89        let variable_lb = variable.induced_lower_bound(state);
90        let variable_ub = variable.induced_upper_bound(state);
91
92        let additional_holes_below_self_lb = match (self.lower_bound, variable_ub) {
93            (IntExt::Int(self_lb), IntExt::Int(variable_ub)) if self_lb > variable_ub => {
94                variable_ub + 1..self_lb
95            }
96
97            _ => 0..0,
98        };
99
100        let additional_holes_above_self_ub = match (self.upper_bound, variable_lb) {
101            (IntExt::Int(self_ub), IntExt::Int(variable_lb)) if self_ub < variable_lb => {
102                self_ub + 1..variable_lb
103            }
104
105            _ => 0..0,
106        };
107
108        self.all_holes.extend(additional_holes_below_self_lb);
109        self.all_holes.extend(additional_holes_above_self_ub);
110
111        self.lower_bound = self.lower_bound.min(variable_lb);
112        self.upper_bound = self.upper_bound.max(variable_ub);
113    }
114
115    /// Returns `true` if the union contains no elements.
116    pub fn is_consistent(&self) -> bool {
117        self.lower_bound <= self.upper_bound
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Comparison::*;
    use crate::TestAtomic;

    #[test]
    fn empty_union_is_inconsistent() {
        let union = Union::empty();
        assert!(!union.is_consistent());
    }

    #[test]
    fn reset_clears_domain() {
        let state = VariableState::prepare_for_conflict_check(
            [TestAtomic {
                name: "x",
                comparison: GreaterEqual,
                value: 4,
            }],
            None,
        )
        .expect("not inconsistent");

        let mut union = Union::empty();

        union.add(&state, &"x");
        union.reset();
        assert!(!union.is_consistent());
        assert!(union.holes().next().is_none());
        assert_eq!(union.size(), None);
    }

    #[test]
    fn adding_single_variable_copies_domain() {
        let state = VariableState::prepare_for_conflict_check(
            [TestAtomic {
                name: "x",
                comparison: GreaterEqual,
                value: 4,
            }],
            None,
        )
        .expect("not inconsistent");

        let mut union = Union::empty();

        union.add(&state, &"x");
        assert!(union.is_consistent());

        assert_eq!(IntExt::<i32>::Int(4), union.lower_bound());
        assert_eq!(IntExt::<i32>::PositiveInf, union.upper_bound());
        assert!(union.holes().next().is_none());

        assert!(union.contains(4));
        assert!(!union.contains(3));
        assert_eq!(union.size(), None);
    }

    #[test]
    fn union_of_non_overlapping_domains() {
        let state = VariableState::prepare_for_conflict_check(
            [
                TestAtomic {
                    name: "x",
                    comparison: GreaterEqual,
                    value: 4,
                },
                TestAtomic {
                    name: "y",
                    comparison: LessEqual,
                    value: 3,
                },
            ],
            None,
        )
        .expect("not inconsistent");

        let mut union = Union::empty();

        union.add(&state, &"x");
        union.add(&state, &"y");

        assert_eq!(IntExt::<i32>::NegativeInf, union.lower_bound());
        assert_eq!(IntExt::<i32>::PositiveInf, union.upper_bound());
        assert!(union.holes().next().is_none());
    }

    #[test]
    fn union_of_domain_with_hole_outside_bounds() {
        let state = VariableState::prepare_for_conflict_check(
            [
                TestAtomic {
                    name: "x",
                    comparison: GreaterEqual,
                    value: 4,
                },
                TestAtomic {
                    name: "y",
                    comparison: NotEqual,
                    value: 2,
                },
            ],
            None,
        )
        .expect("not inconsistent");

        let mut union = Union::empty();

        union.add(&state, &"x");
        union.add(&state, &"y");

        assert_eq!(IntExt::<i32>::NegativeInf, union.lower_bound());
        assert_eq!(IntExt::<i32>::PositiveInf, union.upper_bound());
        assert_eq!(vec![2], union.holes().collect::<Vec<_>>());

        assert!(!union.contains(2));
        assert!(union.contains(3));
    }

    #[test]
    fn union_with_hole_removed_by_add() {
        let state = VariableState::prepare_for_conflict_check(
            [
                TestAtomic {
                    name: "x",
                    comparison: GreaterEqual,
                    value: 4,
                },
                TestAtomic {
                    name: "x",
                    comparison: NotEqual,
                    value: 5,
                },
                TestAtomic {
                    name: "y",
                    comparison: NotEqual,
                    value: 4,
                },
            ],
            None,
        )
        .expect("not inconsistent");

        let mut union = Union::empty();

        union.add(&state, &"x");
        union.add(&state, &"y");

        assert_eq!(IntExt::<i32>::NegativeInf, union.lower_bound());
        assert_eq!(IntExt::<i32>::PositiveInf, union.upper_bound());
        assert!(union.holes().next().is_none());
    }

    #[test]
    fn holes_in_union_and_new_variable_should_be_kept() {
        let state = VariableState::prepare_for_conflict_check(
            [
                TestAtomic {
                    name: "x",
                    comparison: NotEqual,
                    value: 5,
                },
                TestAtomic {
                    name: "y",
                    comparison: NotEqual,
                    value: 5,
                },
            ],
            None,
        )
        .expect("not inconsistent");

        let mut union = Union::empty();
        union.add(&state, &"x");
        union.add(&state, &"y");

        assert_eq!(IntExt::<i32>::NegativeInf, union.lower_bound());
        assert_eq!(IntExt::<i32>::PositiveInf, union.upper_bound());
        assert_eq!(vec![5], union.holes().collect::<Vec<_>>());
    }

    #[test]
    fn union_of_fixed_domains_with_gap() {
        let state = VariableState::prepare_for_conflict_check(
            [
                TestAtomic {
                    name: "x46",
                    comparison: Equal,
                    value: 4,
                },
                TestAtomic {
                    name: "x42",
                    comparison: Equal,
                    value: 1,
                },
            ],
            None,
        )
        .expect("not inconsistent");

        let mut union = Union::empty();
        assert_eq!(union.size(), None);

        union.add(&state, &"x46");
        assert_eq!(union.size(), Some(1));

        union.add(&state, &"x42");

        let holes = union.holes().collect::<BTreeSet<_>>();
        assert_eq!(BTreeSet::from([2, 3]), holes);

        assert_eq!(union.lower_bound(), IntExt::Int(1));
        assert_eq!(union.upper_bound(), IntExt::Int(4));

        // The union is {1, 4}: the gap values are holes and the outside values are excluded.
        assert_eq!(union.size(), Some(2));
        assert!(union.contains(1));
        assert!(union.contains(4));
        assert!(!union.contains(2));
        assert!(!union.contains(3));
        assert!(!union.contains(0));
        assert!(!union.contains(5));
    }
}