pumpkin_core/engine/predicates/
predicate.rs1use enumset::EnumSetType;
2use pumpkin_checking::AtomicConstraint;
3
4use crate::engine::Assignments;
5use crate::engine::variables::DomainId;
6use crate::predicate;
7use crate::propagation::DomainEvent;
8
/// An atomic constraint over a single integer domain, e.g. `[x >= 5]`.
///
/// The representation is bit-packed into eight bytes:
/// - `id`: the two most significant bits hold the predicate type code, and the lower 30 bits
///   hold the raw domain id;
/// - `value`: the right-hand side of the comparison.
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
pub struct Predicate {
    id: u32,
    value: i32,
}
20
/// Two-bit type code (stored in the top bits of a predicate's `id`) for `[x >= v]`.
const LOWER_BOUND_CODE: u8 = 0b00;
/// Two-bit type code (stored in the top bits of a predicate's `id`) for `[x <= v]`.
const UPPER_BOUND_CODE: u8 = 0b01;
/// Two-bit type code (stored in the top bits of a predicate's `id`) for `[x != v]`.
const NOT_EQUAL_CODE: u8 = 0b10;
/// Two-bit type code (stored in the top bits of a predicate's `id`) for `[x == v]`.
const EQUAL_CODE: u8 = 0b11;
25
26impl Predicate {
27 pub fn new(id: DomainId, predicate_type: PredicateType, value: i32) -> Self {
30 let code = match predicate_type {
31 PredicateType::LowerBound => LOWER_BOUND_CODE,
32 PredicateType::UpperBound => UPPER_BOUND_CODE,
33 PredicateType::NotEqual => NOT_EQUAL_CODE,
34 PredicateType::Equal => EQUAL_CODE,
35 };
36 let id = id.id() | (code as u32) << 30;
37 Self { id, value }
38 }
39
40 fn get_type_code(&self) -> u8 {
41 (self.id >> 30) as u8
42 }
43
44 pub fn get_predicate_type(&self) -> PredicateType {
45 (*self).into()
46 }
47}
48
/// The comparison a [`Predicate`] expresses between its domain and its right-hand side.
///
/// The explicit discriminants equal the type codes (`LOWER_BOUND_CODE` .. `EQUAL_CODE`) that
/// `Predicate::new` packs into the top two bits of a predicate's `id`.
#[derive(Debug, Hash, EnumSetType)]
#[repr(u8)]
#[enumset(repr = "u8")]
pub enum PredicateType {
    /// `[x >= v]`
    LowerBound = 0,
    /// `[x <= v]`
    UpperBound = 1,
    /// `[x != v]`
    NotEqual = 2,
    /// `[x == v]`
    Equal = 3,
}
60
61impl From<DomainEvent> for PredicateType {
62 fn from(value: DomainEvent) -> Self {
63 match value {
64 DomainEvent::Assign => PredicateType::Equal,
65 DomainEvent::LowerBound => PredicateType::LowerBound,
66 DomainEvent::UpperBound => PredicateType::UpperBound,
67 DomainEvent::Removal => PredicateType::NotEqual,
68 }
69 }
70}
71
72impl PredicateType {
73 pub fn is_lower_bound(&self) -> bool {
74 matches!(self, PredicateType::LowerBound)
75 }
76
77 pub fn is_upper_bound(&self) -> bool {
78 matches!(self, PredicateType::UpperBound)
79 }
80
81 pub fn is_disequality(&self) -> bool {
82 matches!(self, PredicateType::NotEqual)
83 }
84
85 pub(crate) fn into_predicate(
86 self,
87 domain: DomainId,
88 assignments: &Assignments,
89 removed_value: Option<i32>,
90 ) -> Predicate {
91 match self {
92 PredicateType::LowerBound => {
93 predicate!(domain >= assignments.get_lower_bound(domain))
94 }
95 PredicateType::UpperBound => predicate!(domain <= assignments.get_upper_bound(domain)),
96 PredicateType::NotEqual => predicate!(
97 domain
98 != removed_value
99 .expect("For a `NotEqual`, the removed value should be provided")
100 ),
101 PredicateType::Equal => predicate!(
102 domain
103 == assignments.get_assigned_value(&domain).expect(
104 "Expected domain to be assigned when creating an `Equal` predicate"
105 )
106 ),
107 }
108 }
109}
110
111impl From<Predicate> for PredicateType {
112 fn from(value: Predicate) -> Self {
113 match value.get_type_code() {
114 LOWER_BOUND_CODE => Self::LowerBound,
115 UPPER_BOUND_CODE => Self::UpperBound,
116 EQUAL_CODE => Self::Equal,
117 NOT_EQUAL_CODE => Self::NotEqual,
118 code => panic!("Unknown type code {code}"),
119 }
120 }
121}
122
123impl PredicateType {
124 pub const fn into_bits(self) -> u8 {
125 self as _
126 }
127
128 pub const fn from_bits(value: u8) -> PredicateType {
129 match value {
130 LOWER_BOUND_CODE => PredicateType::LowerBound,
131 UPPER_BOUND_CODE => PredicateType::UpperBound,
132 EQUAL_CODE => PredicateType::Equal,
133 NOT_EQUAL_CODE => PredicateType::NotEqual,
134 _ => panic!("Unknown code"),
135 }
136 }
137}
138
139impl Predicate {
140 pub(crate) fn is_mutually_exclusive_with(self, other: Predicate) -> bool {
141 let domain_id = self.get_domain();
142 let rhs = self.get_right_hand_side();
143
144 let domain_id_other = other.get_domain();
145 let rhs_other = other.get_right_hand_side();
146
147 if domain_id != domain_id_other {
148 return false;
150 }
151
152 match (self.get_predicate_type(), other.get_predicate_type()) {
153 (PredicateType::LowerBound, PredicateType::LowerBound)
154 | (PredicateType::LowerBound, PredicateType::NotEqual)
155 | (PredicateType::UpperBound, PredicateType::UpperBound)
156 | (PredicateType::UpperBound, PredicateType::NotEqual)
157 | (PredicateType::NotEqual, PredicateType::LowerBound)
158 | (PredicateType::NotEqual, PredicateType::UpperBound)
159 | (PredicateType::NotEqual, PredicateType::NotEqual) => false,
160 (PredicateType::LowerBound, PredicateType::UpperBound) => rhs > rhs_other,
161 (PredicateType::UpperBound, PredicateType::LowerBound) => rhs_other > rhs,
162 (PredicateType::LowerBound, PredicateType::Equal) => rhs > rhs_other,
163 (PredicateType::Equal, PredicateType::LowerBound) => rhs_other > rhs,
164 (PredicateType::UpperBound, PredicateType::Equal) => rhs < rhs_other,
165 (PredicateType::Equal, PredicateType::UpperBound) => rhs_other < rhs,
166 (PredicateType::NotEqual, PredicateType::Equal)
167 | (PredicateType::Equal, PredicateType::NotEqual) => rhs == rhs_other,
168 (PredicateType::Equal, PredicateType::Equal) => rhs != rhs_other,
169 }
170 }
171 pub fn is_equality_predicate(&self) -> bool {
172 self.get_type_code() == EQUAL_CODE
173 }
174
175 pub fn is_lower_bound_predicate(&self) -> bool {
176 self.get_type_code() == LOWER_BOUND_CODE
177 }
178
179 pub fn is_upper_bound_predicate(&self) -> bool {
180 self.get_type_code() == UPPER_BOUND_CODE
181 }
182
183 pub fn is_not_equal_predicate(&self) -> bool {
184 self.get_type_code() == NOT_EQUAL_CODE
185 }
186
187 pub fn get_domain(&self) -> DomainId {
189 DomainId::new(0b00111111_11111111_11111111_11111111 & self.id)
190 }
191
192 pub fn get_right_hand_side(&self) -> i32 {
193 self.value
194 }
195
196 pub fn trivially_true() -> Predicate {
197 let domain_id = DomainId::new(0);
200 predicate!(domain_id == 1)
201 }
202
203 pub fn trivially_false() -> Predicate {
204 let domain_id = DomainId::new(0);
207 predicate!(domain_id != 1)
208 }
209}
210
211impl std::ops::Not for Predicate {
212 type Output = Predicate;
213
214 fn not(self) -> Self::Output {
215 let domain_id = self.get_domain();
216 let value = self.get_right_hand_side();
217
218 match self.get_predicate_type() {
219 PredicateType::LowerBound => predicate!(domain_id <= value - 1),
220 PredicateType::UpperBound => predicate!(domain_id >= value + 1),
221 PredicateType::NotEqual => predicate!(domain_id == value),
222 PredicateType::Equal => predicate!(domain_id != value),
223 }
224 }
225}
226
227impl std::fmt::Display for Predicate {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 if *self == Predicate::trivially_true() {
230 write!(f, "[True]")
231 } else if *self == Predicate::trivially_false() {
232 write!(f, "[False]")
233 } else {
234 let domain_id = self.get_domain();
235 let rhs = self.get_right_hand_side();
236
237 match self.get_predicate_type() {
238 PredicateType::LowerBound => write!(f, "[{domain_id} >= {rhs}]"),
239 PredicateType::UpperBound => write!(f, "[{domain_id} <= {rhs}]"),
240 PredicateType::NotEqual => write!(f, "[{domain_id} != {rhs}]"),
241 PredicateType::Equal => write!(f, "[{domain_id} == {rhs}]"),
242 }
243 }
244 }
245}
246
247impl std::fmt::Debug for Predicate {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 write!(f, "{self}")
250 }
251}
252
253impl AtomicConstraint for Predicate {
254 type Identifier = DomainId;
255
256 fn identifier(&self) -> Self::Identifier {
257 self.get_domain()
258 }
259
260 fn comparison(&self) -> pumpkin_checking::Comparison {
261 match self.get_predicate_type() {
262 PredicateType::LowerBound => pumpkin_checking::Comparison::GreaterEqual,
263 PredicateType::UpperBound => pumpkin_checking::Comparison::LessEqual,
264 PredicateType::NotEqual => pumpkin_checking::Comparison::NotEqual,
265 PredicateType::Equal => pumpkin_checking::Comparison::Equal,
266 }
267 }
268
269 fn value(&self) -> i32 {
270 self.get_right_hand_side()
271 }
272
273 fn negate(&self) -> Self {
274 !*self
275 }
276}
277
#[cfg(test)]
mod test {
    use super::Predicate;
    use crate::predicate;
    use crate::variables::DomainId;

    #[test]
    fn are_mutually_exclusive() {
        let domain_id = DomainId::new(0);

        assert!(!predicate!(domain_id >= 5).is_mutually_exclusive_with(predicate!(domain_id >= 7)));
        assert!(!predicate!(domain_id >= 5).is_mutually_exclusive_with(predicate!(domain_id != 2)));
        assert!(!predicate!(domain_id <= 5).is_mutually_exclusive_with(predicate!(domain_id <= 8)));
        assert!(!predicate!(domain_id <= 5).is_mutually_exclusive_with(predicate!(domain_id != 8)));
        assert!(!predicate!(domain_id != 9).is_mutually_exclusive_with(predicate!(domain_id >= 8)));
        assert!(!predicate!(domain_id != 9).is_mutually_exclusive_with(predicate!(domain_id <= 8)));
        assert!(!predicate!(domain_id != 9).is_mutually_exclusive_with(predicate!(domain_id != 8)));

        assert!(predicate!(domain_id <= 7).is_mutually_exclusive_with(predicate!(domain_id >= 8)));
        assert!(predicate!(domain_id >= 8).is_mutually_exclusive_with(predicate!(domain_id <= 7)));

        assert!(predicate!(domain_id >= 8).is_mutually_exclusive_with(predicate!(domain_id == 7)));
        assert!(predicate!(domain_id == 7).is_mutually_exclusive_with(predicate!(domain_id >= 8)));

        assert!(predicate!(domain_id == 7).is_mutually_exclusive_with(predicate!(domain_id <= 6)));
        assert!(predicate!(domain_id <= 6).is_mutually_exclusive_with(predicate!(domain_id == 7)));

        assert!(predicate!(domain_id != 8).is_mutually_exclusive_with(predicate!(domain_id == 8)));
        assert!(predicate!(domain_id == 8).is_mutually_exclusive_with(predicate!(domain_id != 8)));

        assert!(predicate!(domain_id == 7).is_mutually_exclusive_with(predicate!(domain_id == 8)));
    }

    /// Touching bounds and identical equalities are satisfiable together, so they must not be
    /// reported as mutually exclusive (guards against off-by-one regressions).
    #[test]
    fn boundary_cases_are_not_mutually_exclusive() {
        let domain_id = DomainId::new(0);

        assert!(!predicate!(domain_id >= 5).is_mutually_exclusive_with(predicate!(domain_id <= 5)));
        assert!(!predicate!(domain_id <= 5).is_mutually_exclusive_with(predicate!(domain_id >= 5)));
        assert!(!predicate!(domain_id >= 7).is_mutually_exclusive_with(predicate!(domain_id == 7)));
        assert!(!predicate!(domain_id == 7).is_mutually_exclusive_with(predicate!(domain_id >= 7)));
        assert!(!predicate!(domain_id <= 7).is_mutually_exclusive_with(predicate!(domain_id == 7)));
        assert!(!predicate!(domain_id == 7).is_mutually_exclusive_with(predicate!(domain_id <= 7)));
        assert!(!predicate!(domain_id == 7).is_mutually_exclusive_with(predicate!(domain_id == 7)));
        assert!(!predicate!(domain_id != 7).is_mutually_exclusive_with(predicate!(domain_id == 8)));
    }

    #[test]
    fn predicates_over_different_domains_are_not_mutually_exclusive() {
        let x = DomainId::new(0);
        let y = DomainId::new(1);

        assert!(!predicate!(x == 7).is_mutually_exclusive_with(predicate!(y == 8)));
        assert!(!predicate!(x >= 8).is_mutually_exclusive_with(predicate!(y <= 7)));
        assert!(!predicate!(x != 8).is_mutually_exclusive_with(predicate!(y == 8)));
    }

    #[test]
    fn negation_flips_bounds_and_swaps_equalities() {
        let domain_id = DomainId::new(0);

        assert_eq!(!predicate!(domain_id >= 5), predicate!(domain_id <= 4));
        assert_eq!(!predicate!(domain_id <= 5), predicate!(domain_id >= 6));
        assert_eq!(!predicate!(domain_id == 5), predicate!(domain_id != 5));
        assert_eq!(!predicate!(domain_id != 5), predicate!(domain_id == 5));
    }

    #[test]
    fn double_negation_is_identity() {
        let domain_id = DomainId::new(0);

        for predicate in [
            predicate!(domain_id >= 5),
            predicate!(domain_id <= 5),
            predicate!(domain_id != 5),
            predicate!(domain_id == 5),
        ] {
            assert_eq!(!(!predicate), predicate);
        }
    }

    #[test]
    fn negating_trivially_true_predicate() {
        let trivially_true = Predicate::trivially_true();
        let trivially_false = Predicate::trivially_false();
        assert!(!trivially_true == trivially_false);
    }

    #[test]
    fn negating_trivially_false_predicate() {
        let trivially_true = Predicate::trivially_true();
        let trivially_false = Predicate::trivially_false();
        assert!(!trivially_false == trivially_true);
    }
}