Skip to main content

pumpkin_core/engine/predicates/
predicate.rs

1use enumset::EnumSetType;
2use pumpkin_checking::AtomicConstraint;
3
4use crate::engine::Assignments;
5use crate::engine::variables::DomainId;
6use crate::predicate;
7use crate::propagation::DomainEvent;
8
/// Representation of a domain operation, also known as an atomic constraint. It is a triple
/// ([`DomainId`], [`PredicateType`], value).
///
/// To create a [`Predicate`], use [Predicate::new] or the more concise [predicate!] macro.
#[derive(Clone, PartialEq, Eq, Copy, Hash)]
pub struct Predicate {
    /// Packed identifier. The two most significant bits contain the type code of the predicate
    /// (one of the `*_CODE` constants, matching the [`PredicateType`] discriminants). The
    /// remaining lower 30 bits contain the id of the [`DomainId`] the predicate constrains.
    id: u32,
    /// The right-hand side of the predicate, i.e. the value the domain is compared against.
    value: i32,
}
20
// 2-bit type codes stored in the two most significant bits of `Predicate::id`. They must stay
// in sync with the literal discriminants of `PredicateType`.

/// Type code of a lower-bound predicate (`[x >= v]`).
const LOWER_BOUND_CODE: u8 = 0b00;
/// Type code of an upper-bound predicate (`[x <= v]`).
const UPPER_BOUND_CODE: u8 = 0b01;
/// Type code of a disequality predicate (`[x != v]`).
const NOT_EQUAL_CODE: u8 = 0b10;
/// Type code of an equality predicate (`[x == v]`).
const EQUAL_CODE: u8 = 0b11;
25
26impl Predicate {
27    /// Creates a new [`Predicate`] (also known as atomic constraint) which represents a domain
28    /// operation.
29    pub fn new(id: DomainId, predicate_type: PredicateType, value: i32) -> Self {
30        let code = match predicate_type {
31            PredicateType::LowerBound => LOWER_BOUND_CODE,
32            PredicateType::UpperBound => UPPER_BOUND_CODE,
33            PredicateType::NotEqual => NOT_EQUAL_CODE,
34            PredicateType::Equal => EQUAL_CODE,
35        };
36        let id = id.id() | (code as u32) << 30;
37        Self { id, value }
38    }
39
40    fn get_type_code(&self) -> u8 {
41        (self.id >> 30) as u8
42    }
43
44    pub fn get_predicate_type(&self) -> PredicateType {
45        (*self).into()
46    }
47}
48
/// The kind of comparison a [`Predicate`] expresses between its domain and its right-hand side.
///
/// The discriminant of each variant equals the 2-bit type code that [`Predicate::new`] stores in
/// the top bits of the packed id.
#[derive(Debug, Hash, EnumSetType)]
#[repr(u8)]
#[enumset(repr = "u8")]
pub enum PredicateType {
    // These discriminants must match LOWER_BOUND_CODE, UPPER_BOUND_CODE, NOT_EQUAL_CODE and
    // EQUAL_CODE. `EnumSetType` requires literal discriminants, so the constants cannot be
    // referenced here directly.
    /// `[domain >= value]`
    LowerBound = 0,
    /// `[domain <= value]`
    UpperBound = 1,
    /// `[domain != value]`
    NotEqual = 2,
    /// `[domain == value]`
    Equal = 3,
}
60
61impl From<DomainEvent> for PredicateType {
62    fn from(value: DomainEvent) -> Self {
63        match value {
64            DomainEvent::Assign => PredicateType::Equal,
65            DomainEvent::LowerBound => PredicateType::LowerBound,
66            DomainEvent::UpperBound => PredicateType::UpperBound,
67            DomainEvent::Removal => PredicateType::NotEqual,
68        }
69    }
70}
71
72impl PredicateType {
73    pub fn is_lower_bound(&self) -> bool {
74        matches!(self, PredicateType::LowerBound)
75    }
76
77    pub fn is_upper_bound(&self) -> bool {
78        matches!(self, PredicateType::UpperBound)
79    }
80
81    pub fn is_disequality(&self) -> bool {
82        matches!(self, PredicateType::NotEqual)
83    }
84
85    pub(crate) fn into_predicate(
86        self,
87        domain: DomainId,
88        assignments: &Assignments,
89        removed_value: Option<i32>,
90    ) -> Predicate {
91        match self {
92            PredicateType::LowerBound => {
93                predicate!(domain >= assignments.get_lower_bound(domain))
94            }
95            PredicateType::UpperBound => predicate!(domain <= assignments.get_upper_bound(domain)),
96            PredicateType::NotEqual => predicate!(
97                domain
98                    != removed_value
99                        .expect("For a `NotEqual`, the removed value should be provided")
100            ),
101            PredicateType::Equal => predicate!(
102                domain
103                    == assignments.get_assigned_value(&domain).expect(
104                        "Expected domain to be assigned when creating an `Equal` predicate"
105                    )
106            ),
107        }
108    }
109}
110
111impl From<Predicate> for PredicateType {
112    fn from(value: Predicate) -> Self {
113        match value.get_type_code() {
114            LOWER_BOUND_CODE => Self::LowerBound,
115            UPPER_BOUND_CODE => Self::UpperBound,
116            EQUAL_CODE => Self::Equal,
117            NOT_EQUAL_CODE => Self::NotEqual,
118            code => panic!("Unknown type code {code}"),
119        }
120    }
121}
122
123impl PredicateType {
124    pub const fn into_bits(self) -> u8 {
125        self as _
126    }
127
128    pub const fn from_bits(value: u8) -> PredicateType {
129        match value {
130            LOWER_BOUND_CODE => PredicateType::LowerBound,
131            UPPER_BOUND_CODE => PredicateType::UpperBound,
132            EQUAL_CODE => PredicateType::Equal,
133            NOT_EQUAL_CODE => PredicateType::NotEqual,
134            _ => panic!("Unknown code"),
135        }
136    }
137}
138
139impl Predicate {
140    pub(crate) fn is_mutually_exclusive_with(self, other: Predicate) -> bool {
141        let domain_id = self.get_domain();
142        let rhs = self.get_right_hand_side();
143
144        let domain_id_other = other.get_domain();
145        let rhs_other = other.get_right_hand_side();
146
147        if domain_id != domain_id_other {
148            // Domain Ids do not match
149            return false;
150        }
151
152        match (self.get_predicate_type(), other.get_predicate_type()) {
153            (PredicateType::LowerBound, PredicateType::LowerBound)
154            | (PredicateType::LowerBound, PredicateType::NotEqual)
155            | (PredicateType::UpperBound, PredicateType::UpperBound)
156            | (PredicateType::UpperBound, PredicateType::NotEqual)
157            | (PredicateType::NotEqual, PredicateType::LowerBound)
158            | (PredicateType::NotEqual, PredicateType::UpperBound)
159            | (PredicateType::NotEqual, PredicateType::NotEqual) => false,
160            (PredicateType::LowerBound, PredicateType::UpperBound) => rhs > rhs_other,
161            (PredicateType::UpperBound, PredicateType::LowerBound) => rhs_other > rhs,
162            (PredicateType::LowerBound, PredicateType::Equal) => rhs > rhs_other,
163            (PredicateType::Equal, PredicateType::LowerBound) => rhs_other > rhs,
164            (PredicateType::UpperBound, PredicateType::Equal) => rhs < rhs_other,
165            (PredicateType::Equal, PredicateType::UpperBound) => rhs_other < rhs,
166            (PredicateType::NotEqual, PredicateType::Equal)
167            | (PredicateType::Equal, PredicateType::NotEqual) => rhs == rhs_other,
168            (PredicateType::Equal, PredicateType::Equal) => rhs != rhs_other,
169        }
170    }
171    pub fn is_equality_predicate(&self) -> bool {
172        self.get_type_code() == EQUAL_CODE
173    }
174
175    pub fn is_lower_bound_predicate(&self) -> bool {
176        self.get_type_code() == LOWER_BOUND_CODE
177    }
178
179    pub fn is_upper_bound_predicate(&self) -> bool {
180        self.get_type_code() == UPPER_BOUND_CODE
181    }
182
183    pub fn is_not_equal_predicate(&self) -> bool {
184        self.get_type_code() == NOT_EQUAL_CODE
185    }
186
187    /// Returns the [`DomainId`] of the [`Predicate`]
188    pub fn get_domain(&self) -> DomainId {
189        DomainId::new(0b00111111_11111111_11111111_11111111 & self.id)
190    }
191
192    pub fn get_right_hand_side(&self) -> i32 {
193        self.value
194    }
195
196    pub fn trivially_true() -> Predicate {
197        // By convention, there is a dummy 0-1 variable set to one at root.
198        // We use it to denote the trivially true predicate.
199        let domain_id = DomainId::new(0);
200        predicate!(domain_id == 1)
201    }
202
203    pub fn trivially_false() -> Predicate {
204        // By convention, there is a dummy 0-1 variable set to one at root.
205        // We use it to denote the trivially true predicate.
206        let domain_id = DomainId::new(0);
207        predicate!(domain_id != 1)
208    }
209}
210
211impl std::ops::Not for Predicate {
212    type Output = Predicate;
213
214    fn not(self) -> Self::Output {
215        let domain_id = self.get_domain();
216        let value = self.get_right_hand_side();
217
218        match self.get_predicate_type() {
219            PredicateType::LowerBound => predicate!(domain_id <= value - 1),
220            PredicateType::UpperBound => predicate!(domain_id >= value + 1),
221            PredicateType::NotEqual => predicate!(domain_id == value),
222            PredicateType::Equal => predicate!(domain_id != value),
223        }
224    }
225}
226
227impl std::fmt::Display for Predicate {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        if *self == Predicate::trivially_true() {
230            write!(f, "[True]")
231        } else if *self == Predicate::trivially_false() {
232            write!(f, "[False]")
233        } else {
234            let domain_id = self.get_domain();
235            let rhs = self.get_right_hand_side();
236
237            match self.get_predicate_type() {
238                PredicateType::LowerBound => write!(f, "[{domain_id} >= {rhs}]"),
239                PredicateType::UpperBound => write!(f, "[{domain_id} <= {rhs}]"),
240                PredicateType::NotEqual => write!(f, "[{domain_id} != {rhs}]"),
241                PredicateType::Equal => write!(f, "[{domain_id} == {rhs}]"),
242            }
243        }
244    }
245}
246
247impl std::fmt::Debug for Predicate {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        write!(f, "{self}")
250    }
251}
252
253impl AtomicConstraint for Predicate {
254    type Identifier = DomainId;
255
256    fn identifier(&self) -> Self::Identifier {
257        self.get_domain()
258    }
259
260    fn comparison(&self) -> pumpkin_checking::Comparison {
261        match self.get_predicate_type() {
262            PredicateType::LowerBound => pumpkin_checking::Comparison::GreaterEqual,
263            PredicateType::UpperBound => pumpkin_checking::Comparison::LessEqual,
264            PredicateType::NotEqual => pumpkin_checking::Comparison::NotEqual,
265            PredicateType::Equal => pumpkin_checking::Comparison::Equal,
266        }
267    }
268
269    fn value(&self) -> i32 {
270        self.get_right_hand_side()
271    }
272
273    fn negate(&self) -> Self {
274        !*self
275    }
276}
277
#[cfg(test)]
mod test {
    use super::Predicate;
    use crate::predicate;
    use crate::variables::DomainId;

    #[test]
    fn are_mutually_exclusive() {
        let domain_id = DomainId::new(0);

        // Pairs that can both hold for some value of the domain.
        let compatible = [
            (predicate!(domain_id >= 5), predicate!(domain_id >= 7)),
            (predicate!(domain_id >= 5), predicate!(domain_id != 2)),
            (predicate!(domain_id <= 5), predicate!(domain_id <= 8)),
            (predicate!(domain_id <= 5), predicate!(domain_id != 8)),
            (predicate!(domain_id != 9), predicate!(domain_id >= 8)),
            (predicate!(domain_id != 9), predicate!(domain_id <= 8)),
            (predicate!(domain_id != 9), predicate!(domain_id != 8)),
        ];
        for (lhs, rhs) in compatible {
            assert!(!lhs.is_mutually_exclusive_with(rhs), "{lhs} and {rhs} should be compatible");
        }

        // Pairs that can never both hold.
        let exclusive = [
            (predicate!(domain_id <= 7), predicate!(domain_id >= 8)),
            (predicate!(domain_id >= 8), predicate!(domain_id <= 7)),
            (predicate!(domain_id >= 8), predicate!(domain_id == 7)),
            (predicate!(domain_id == 7), predicate!(domain_id >= 8)),
            (predicate!(domain_id == 7), predicate!(domain_id <= 6)),
            (predicate!(domain_id <= 6), predicate!(domain_id == 7)),
            (predicate!(domain_id != 8), predicate!(domain_id == 8)),
            (predicate!(domain_id == 8), predicate!(domain_id != 8)),
            (predicate!(domain_id == 7), predicate!(domain_id == 8)),
        ];
        for (lhs, rhs) in exclusive {
            assert!(lhs.is_mutually_exclusive_with(rhs), "{lhs} and {rhs} should be exclusive");
        }
    }

    #[test]
    fn negating_trivially_true_predicate() {
        assert_eq!(!Predicate::trivially_true(), Predicate::trivially_false());
    }

    #[test]
    fn negating_trivially_false_predicate() {
        assert_eq!(!Predicate::trivially_false(), Predicate::trivially_true());
    }
}