pumpkin_core/engine/predicates/
predicate.rs1use crate::engine::variables::DomainId;
2use crate::engine::Assignments;
3use crate::predicate;
4
/// An atomic constraint over a single integer domain, such as `[x >= 5]` or `[x != 3]`.
///
/// The predicate is stored as two 32-bit words: `id` packs the domain id together with a
/// two-bit predicate-type code, and `value` is the right-hand side of the comparison.
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
pub struct Predicate {
    /// Bits `0..30` hold the domain id; bits `30..32` hold the predicate-type code.
    id: u32,
    /// The right-hand side of the comparison.
    value: i32,
}

// Two-bit predicate-type codes stored in the top bits of `Predicate::id`. They also serve as
// the discriminants of `PredicateType`, so they must stay within `0..=3`.

/// Code for `[x == v]`.
const EQUAL_CODE: u8 = 0;
/// Code for `[x >= v]`.
const LOWER_BOUND_CODE: u8 = 1;
/// Code for `[x <= v]`.
const UPPER_BOUND_CODE: u8 = 2;
/// Code for `[x != v]`.
const NOT_EQUAL_CODE: u8 = 3;

22impl Predicate {
23 pub fn new(id: DomainId, predicate_type: PredicateType, value: i32) -> Self {
26 let code = match predicate_type {
27 PredicateType::LowerBound => LOWER_BOUND_CODE,
28 PredicateType::UpperBound => UPPER_BOUND_CODE,
29 PredicateType::NotEqual => NOT_EQUAL_CODE,
30 PredicateType::Equal => EQUAL_CODE,
31 };
32 let id = id.id() | (code as u32) << 30;
33 Self { id, value }
34 }
35
36 fn get_type_code(&self) -> u8 {
37 (self.id >> 30) as u8
38 }
39
40 pub fn get_predicate_type(&self) -> PredicateType {
41 (*self).into()
42 }
43}
44
45#[derive(Debug, Clone, Eq, PartialEq, Copy, Hash)]
46#[repr(u8)]
47pub enum PredicateType {
48 LowerBound = LOWER_BOUND_CODE,
49 UpperBound = UPPER_BOUND_CODE,
50 NotEqual = NOT_EQUAL_CODE,
51 Equal = EQUAL_CODE,
52}
53impl PredicateType {
54 pub(crate) fn is_lower_bound(&self) -> bool {
55 matches!(self, PredicateType::LowerBound)
56 }
57
58 pub(crate) fn is_upper_bound(&self) -> bool {
59 matches!(self, PredicateType::UpperBound)
60 }
61
62 pub(crate) fn is_disequality(&self) -> bool {
63 matches!(self, PredicateType::NotEqual)
64 }
65
66 pub(crate) fn into_predicate(
67 self,
68 domain: DomainId,
69 assignments: &Assignments,
70 removed_value: Option<i32>,
71 ) -> Predicate {
72 match self {
73 PredicateType::LowerBound => {
74 predicate!(domain >= assignments.get_lower_bound(domain))
75 }
76 PredicateType::UpperBound => predicate!(domain <= assignments.get_upper_bound(domain)),
77 PredicateType::NotEqual => predicate!(
78 domain
79 != removed_value
80 .expect("For a `NotEqual`, the removed value should be provided")
81 ),
82 PredicateType::Equal => predicate!(
83 domain
84 == assignments.get_assigned_value(&domain).expect(
85 "Expected domain to be assigned when creating an `Equal` predicate"
86 )
87 ),
88 }
89 }
90}
91
92impl From<Predicate> for PredicateType {
93 fn from(value: Predicate) -> Self {
94 match value.get_type_code() {
95 LOWER_BOUND_CODE => Self::LowerBound,
96 UPPER_BOUND_CODE => Self::UpperBound,
97 EQUAL_CODE => Self::Equal,
98 NOT_EQUAL_CODE => Self::NotEqual,
99 code => panic!("Unknown type code {code}"),
100 }
101 }
102}
103
104impl PredicateType {
105 pub const fn into_bits(self) -> u8 {
106 self as _
107 }
108
109 pub const fn from_bits(value: u8) -> PredicateType {
110 match value {
111 LOWER_BOUND_CODE => PredicateType::LowerBound,
112 UPPER_BOUND_CODE => PredicateType::UpperBound,
113 EQUAL_CODE => PredicateType::Equal,
114 NOT_EQUAL_CODE => PredicateType::NotEqual,
115 _ => panic!("Unknown code"),
116 }
117 }
118}
119
120impl Predicate {
121 pub(crate) fn is_mutually_exclusive_with(self, other: Predicate) -> bool {
122 let domain_id = self.get_domain();
123 let rhs = self.get_right_hand_side();
124
125 let domain_id_other = other.get_domain();
126 let rhs_other = other.get_right_hand_side();
127
128 if domain_id != domain_id_other {
129 return false;
131 }
132
133 match (self.get_predicate_type(), other.get_predicate_type()) {
134 (PredicateType::LowerBound, PredicateType::LowerBound)
135 | (PredicateType::LowerBound, PredicateType::NotEqual)
136 | (PredicateType::UpperBound, PredicateType::UpperBound)
137 | (PredicateType::UpperBound, PredicateType::NotEqual)
138 | (PredicateType::NotEqual, PredicateType::LowerBound)
139 | (PredicateType::NotEqual, PredicateType::UpperBound)
140 | (PredicateType::NotEqual, PredicateType::NotEqual) => false,
141 (PredicateType::LowerBound, PredicateType::UpperBound) => rhs > rhs_other,
142 (PredicateType::UpperBound, PredicateType::LowerBound) => rhs_other > rhs,
143 (PredicateType::LowerBound, PredicateType::Equal) => rhs > rhs_other,
144 (PredicateType::Equal, PredicateType::LowerBound) => rhs_other > rhs,
145 (PredicateType::UpperBound, PredicateType::Equal) => rhs < rhs_other,
146 (PredicateType::Equal, PredicateType::UpperBound) => rhs_other < rhs,
147 (PredicateType::NotEqual, PredicateType::Equal)
148 | (PredicateType::Equal, PredicateType::NotEqual) => rhs == rhs_other,
149 (PredicateType::Equal, PredicateType::Equal) => rhs != rhs_other,
150 }
151 }
152 pub fn is_equality_predicate(&self) -> bool {
153 self.get_type_code() == EQUAL_CODE
154 }
155
156 pub fn is_lower_bound_predicate(&self) -> bool {
157 self.get_type_code() == LOWER_BOUND_CODE
158 }
159
160 pub fn is_upper_bound_predicate(&self) -> bool {
161 self.get_type_code() == UPPER_BOUND_CODE
162 }
163
164 pub fn is_not_equal_predicate(&self) -> bool {
165 self.get_type_code() == NOT_EQUAL_CODE
166 }
167
168 pub fn get_domain(&self) -> DomainId {
170 DomainId::new(0b00111111_11111111_11111111_11111111 & self.id)
171 }
172
173 pub fn get_right_hand_side(&self) -> i32 {
174 self.value
175 }
176
177 pub fn trivially_true() -> Predicate {
178 let domain_id = DomainId::new(0);
181 predicate!(domain_id == 1)
182 }
183
184 pub fn trivially_false() -> Predicate {
185 let domain_id = DomainId::new(0);
188 predicate!(domain_id != 1)
189 }
190}
191
192impl std::ops::Not for Predicate {
193 type Output = Predicate;
194
195 fn not(self) -> Self::Output {
196 let domain_id = self.get_domain();
197 let value = self.get_right_hand_side();
198
199 match self.get_predicate_type() {
200 PredicateType::LowerBound => predicate!(domain_id <= value - 1),
201 PredicateType::UpperBound => predicate!(domain_id >= value + 1),
202 PredicateType::NotEqual => predicate!(domain_id == value),
203 PredicateType::Equal => predicate!(domain_id != value),
204 }
205 }
206}
207
208impl std::fmt::Display for Predicate {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 if *self == Predicate::trivially_true() {
211 write!(f, "[True]")
212 } else if *self == Predicate::trivially_false() {
213 write!(f, "[False]")
214 } else {
215 let domain_id = self.get_domain();
216 let rhs = self.get_right_hand_side();
217
218 match self.get_predicate_type() {
219 PredicateType::LowerBound => write!(f, "[{domain_id} >= {rhs}]"),
220 PredicateType::UpperBound => write!(f, "[{domain_id} <= {rhs}]"),
221 PredicateType::NotEqual => write!(f, "[{domain_id} != {rhs}]"),
222 PredicateType::Equal => write!(f, "[{domain_id} == {rhs}]"),
223 }
224 }
225 }
226}
227
228impl std::fmt::Debug for Predicate {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 write!(f, "{self}")
231 }
232}
233
#[cfg(test)]
mod test {
    use super::Predicate;
    use crate::predicate;
    use crate::variables::DomainId;

    #[test]
    fn are_mutually_exclusive() {
        let domain_id = DomainId::new(0);

        // Pairs that can hold simultaneously and so must not be reported as exclusive.
        let compatible = [
            (predicate!(domain_id >= 5), predicate!(domain_id >= 7)),
            (predicate!(domain_id >= 5), predicate!(domain_id != 2)),
            (predicate!(domain_id <= 5), predicate!(domain_id <= 8)),
            (predicate!(domain_id <= 5), predicate!(domain_id != 8)),
            (predicate!(domain_id != 9), predicate!(domain_id >= 8)),
            (predicate!(domain_id != 9), predicate!(domain_id <= 8)),
            (predicate!(domain_id != 9), predicate!(domain_id != 8)),
        ];
        for (first, second) in compatible {
            assert!(
                !first.is_mutually_exclusive_with(second),
                "{first} and {second} should not be mutually exclusive"
            );
        }

        // Pairs that can never hold together.
        let exclusive = [
            (predicate!(domain_id <= 7), predicate!(domain_id >= 8)),
            (predicate!(domain_id >= 8), predicate!(domain_id <= 7)),
            (predicate!(domain_id >= 8), predicate!(domain_id == 7)),
            (predicate!(domain_id == 7), predicate!(domain_id >= 8)),
            (predicate!(domain_id == 7), predicate!(domain_id <= 6)),
            (predicate!(domain_id <= 6), predicate!(domain_id == 7)),
            (predicate!(domain_id != 8), predicate!(domain_id == 8)),
            (predicate!(domain_id == 8), predicate!(domain_id != 8)),
            (predicate!(domain_id == 7), predicate!(domain_id == 8)),
        ];
        for (first, second) in exclusive {
            assert!(
                first.is_mutually_exclusive_with(second),
                "{first} and {second} should be mutually exclusive"
            );
        }
    }

    #[test]
    fn negating_trivially_true_predicate() {
        assert_eq!(!Predicate::trivially_true(), Predicate::trivially_false());
    }

    #[test]
    fn negating_trivially_false_predicate() {
        assert_eq!(!Predicate::trivially_false(), Predicate::trivially_true());
    }
}