pumpkin_core/engine/predicates/
predicate.rs1use crate::engine::variables::DomainId;
2use crate::engine::Assignments;
3use crate::predicate;
4
/// An atomic constraint `[domain <op> value]` over a single integer domain, packed into
/// eight bytes.
///
/// The `id` field carries two pieces of information: its low 30 bits hold the domain id
/// and its two most significant bits hold the predicate-type code. The `value` field is
/// the right-hand side of the constraint.
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
pub struct Predicate {
    /// Domain id (low 30 bits) combined with the predicate-type code (high 2 bits).
    id: u32,
    /// Right-hand side of the predicate.
    value: i32,
}
16
// Two-bit predicate-type codes, stored in the most significant bits of `Predicate::id`.

/// Type code of an equality predicate `[x == v]`.
const EQUAL_CODE: u8 = 0;
/// Type code of a lower-bound predicate `[x >= v]`.
const LOWER_BOUND_CODE: u8 = 1;
/// Type code of an upper-bound predicate `[x <= v]`.
const UPPER_BOUND_CODE: u8 = 2;
/// Type code of a disequality predicate `[x != v]`.
const NOT_EQUAL_CODE: u8 = 3;
21
22impl Predicate {
23 pub fn new(id: DomainId, predicate_type: PredicateType, value: i32) -> Self {
26 let code = match predicate_type {
27 PredicateType::LowerBound => LOWER_BOUND_CODE,
28 PredicateType::UpperBound => UPPER_BOUND_CODE,
29 PredicateType::NotEqual => NOT_EQUAL_CODE,
30 PredicateType::Equal => EQUAL_CODE,
31 };
32 let id = id.id() | (code as u32) << 30;
33 Self { id, value }
34 }
35
36 fn get_type_code(&self) -> u8 {
37 (self.id >> 30) as u8
38 }
39
40 pub fn get_predicate_type(&self) -> PredicateType {
41 (*self).into()
42 }
43}
44
/// The comparison operator of a [`Predicate`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PredicateType {
    /// A lower-bound predicate `[x >= v]`.
    LowerBound,
    /// An upper-bound predicate `[x <= v]`.
    UpperBound,
    /// A disequality predicate `[x != v]`.
    NotEqual,
    /// An equality predicate `[x == v]`.
    Equal,
}
52impl PredicateType {
53 pub(crate) fn is_lower_bound(&self) -> bool {
54 matches!(self, PredicateType::LowerBound)
55 }
56
57 pub(crate) fn is_upper_bound(&self) -> bool {
58 matches!(self, PredicateType::UpperBound)
59 }
60
61 pub(crate) fn into_predicate(
62 self,
63 domain: DomainId,
64 assignments: &Assignments,
65 removed_value: Option<i32>,
66 ) -> Predicate {
67 match self {
68 PredicateType::LowerBound => {
69 predicate!(domain >= assignments.get_lower_bound(domain))
70 }
71 PredicateType::UpperBound => predicate!(domain <= assignments.get_upper_bound(domain)),
72 PredicateType::NotEqual => predicate!(
73 domain
74 != removed_value
75 .expect("For a `NotEqual`, the removed value should be provided")
76 ),
77 PredicateType::Equal => predicate!(
78 domain
79 == assignments.get_assigned_value(&domain).expect(
80 "Expected domain to be assigned when creating an `Equal` predicate"
81 )
82 ),
83 }
84 }
85}
86
87impl From<Predicate> for PredicateType {
88 fn from(value: Predicate) -> Self {
89 match value.get_type_code() {
90 LOWER_BOUND_CODE => Self::LowerBound,
91 UPPER_BOUND_CODE => Self::UpperBound,
92 EQUAL_CODE => Self::Equal,
93 NOT_EQUAL_CODE => Self::NotEqual,
94 code => panic!("Unknown type code {code}"),
95 }
96 }
97}
98
99impl Predicate {
100 pub(crate) fn is_mutually_exclusive_with(self, other: Predicate) -> bool {
101 let domain_id = self.get_domain();
102 let rhs = self.get_right_hand_side();
103
104 let domain_id_other = other.get_domain();
105 let rhs_other = other.get_right_hand_side();
106
107 if domain_id != domain_id_other {
108 return false;
110 }
111
112 match (self.get_predicate_type(), other.get_predicate_type()) {
113 (PredicateType::LowerBound, PredicateType::LowerBound)
114 | (PredicateType::LowerBound, PredicateType::NotEqual)
115 | (PredicateType::UpperBound, PredicateType::UpperBound)
116 | (PredicateType::UpperBound, PredicateType::NotEqual)
117 | (PredicateType::NotEqual, PredicateType::LowerBound)
118 | (PredicateType::NotEqual, PredicateType::UpperBound)
119 | (PredicateType::NotEqual, PredicateType::NotEqual) => false,
120 (PredicateType::LowerBound, PredicateType::UpperBound) => rhs > rhs_other,
121 (PredicateType::UpperBound, PredicateType::LowerBound) => rhs_other > rhs,
122 (PredicateType::LowerBound, PredicateType::Equal) => rhs > rhs_other,
123 (PredicateType::Equal, PredicateType::LowerBound) => rhs_other > rhs,
124 (PredicateType::UpperBound, PredicateType::Equal) => rhs < rhs_other,
125 (PredicateType::Equal, PredicateType::UpperBound) => rhs_other < rhs,
126 (PredicateType::NotEqual, PredicateType::Equal)
127 | (PredicateType::Equal, PredicateType::NotEqual) => rhs == rhs_other,
128 (PredicateType::Equal, PredicateType::Equal) => rhs != rhs_other,
129 }
130 }
131 pub fn is_equality_predicate(&self) -> bool {
132 self.get_type_code() == EQUAL_CODE
133 }
134
135 pub fn is_lower_bound_predicate(&self) -> bool {
136 self.get_type_code() == LOWER_BOUND_CODE
137 }
138
139 pub fn is_upper_bound_predicate(&self) -> bool {
140 self.get_type_code() == UPPER_BOUND_CODE
141 }
142
143 pub fn is_not_equal_predicate(&self) -> bool {
144 self.get_type_code() == NOT_EQUAL_CODE
145 }
146
147 pub fn get_domain(&self) -> DomainId {
149 DomainId::new(0b00111111_11111111_11111111_11111111 & self.id)
150 }
151
152 pub fn get_right_hand_side(&self) -> i32 {
153 self.value
154 }
155
156 pub fn trivially_true() -> Predicate {
157 let domain_id = DomainId::new(0);
160 predicate!(domain_id == 1)
161 }
162
163 pub fn trivially_false() -> Predicate {
164 let domain_id = DomainId::new(0);
167 predicate!(domain_id != 1)
168 }
169}
170
171impl std::ops::Not for Predicate {
172 type Output = Predicate;
173
174 fn not(self) -> Self::Output {
175 let domain_id = self.get_domain();
176 let value = self.get_right_hand_side();
177
178 match self.get_predicate_type() {
179 PredicateType::LowerBound => predicate!(domain_id <= value - 1),
180 PredicateType::UpperBound => predicate!(domain_id >= value + 1),
181 PredicateType::NotEqual => predicate!(domain_id == value),
182 PredicateType::Equal => predicate!(domain_id != value),
183 }
184 }
185}
186
187impl std::fmt::Display for Predicate {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 if *self == Predicate::trivially_true() {
190 write!(f, "[True]")
191 } else if *self == Predicate::trivially_false() {
192 write!(f, "[False]")
193 } else {
194 let domain_id = self.get_domain();
195 let rhs = self.get_right_hand_side();
196
197 match self.get_predicate_type() {
198 PredicateType::LowerBound => write!(f, "[{domain_id} >= {rhs}]"),
199 PredicateType::UpperBound => write!(f, "[{domain_id} <= {rhs}]"),
200 PredicateType::NotEqual => write!(f, "[{domain_id} != {rhs}]"),
201 PredicateType::Equal => write!(f, "[{domain_id} == {rhs}]"),
202 }
203 }
204 }
205}
206
207impl std::fmt::Debug for Predicate {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 write!(f, "{self}")
210 }
211}
212
#[cfg(test)]
mod test {
    use super::Predicate;
    use crate::predicate;
    use crate::variables::DomainId;

    #[test]
    fn are_mutually_exclusive() {
        let domain_id = DomainId::new(0);

        // Pairs over the same domain that can be satisfied simultaneously.
        let compatible = [
            (predicate!(domain_id >= 5), predicate!(domain_id >= 7)),
            (predicate!(domain_id >= 5), predicate!(domain_id != 2)),
            (predicate!(domain_id <= 5), predicate!(domain_id <= 8)),
            (predicate!(domain_id <= 5), predicate!(domain_id != 8)),
            (predicate!(domain_id != 9), predicate!(domain_id >= 8)),
            (predicate!(domain_id != 9), predicate!(domain_id <= 8)),
            (predicate!(domain_id != 9), predicate!(domain_id != 8)),
        ];
        for &(lhs, rhs) in &compatible {
            assert!(!lhs.is_mutually_exclusive_with(rhs));
        }

        // Pairs that can never hold together, checked in the listed order.
        let exclusive = [
            (predicate!(domain_id <= 7), predicate!(domain_id >= 8)),
            (predicate!(domain_id >= 8), predicate!(domain_id <= 7)),
            (predicate!(domain_id >= 8), predicate!(domain_id == 7)),
            (predicate!(domain_id == 7), predicate!(domain_id >= 8)),
            (predicate!(domain_id == 7), predicate!(domain_id <= 6)),
            (predicate!(domain_id <= 6), predicate!(domain_id == 7)),
            (predicate!(domain_id != 8), predicate!(domain_id == 8)),
            (predicate!(domain_id == 8), predicate!(domain_id != 8)),
            (predicate!(domain_id == 7), predicate!(domain_id == 8)),
        ];
        for &(lhs, rhs) in &exclusive {
            assert!(lhs.is_mutually_exclusive_with(rhs));
        }
    }

    #[test]
    fn negating_trivially_true_predicate() {
        let negated = !Predicate::trivially_true();
        assert!(negated == Predicate::trivially_false());
    }

    #[test]
    fn negating_trivially_false_predicate() {
        let negated = !Predicate::trivially_false();
        assert!(negated == Predicate::trivially_true());
    }
}