// pumpkin_core/engine/predicates/predicate.rs

1use crate::engine::variables::DomainId;
2use crate::engine::Assignments;
3use crate::predicate;
4
/// A domain operation, also called an atomic constraint: the triple
/// ([`DomainId`], [`PredicateType`], value).
///
/// Construct one with [`Predicate::new`] or, more concisely, with the [`predicate!`] macro.
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
pub struct Predicate {
    /// Packed identifier: the low 30 bits hold the numeric id of the [`DomainId`], while the two
    /// most significant bits hold the type code of the [`PredicateType`].
    id: u32,
    /// The right-hand side value of the atomic constraint.
    value: i32,
}

// Type codes stored in the two most significant bits of `Predicate::id`; they must therefore
// all fit in the range 0..=3.

/// Type code of [`PredicateType::Equal`].
const EQUAL_CODE: u8 = 0;
/// Type code of [`PredicateType::LowerBound`].
const LOWER_BOUND_CODE: u8 = 1;
/// Type code of [`PredicateType::UpperBound`].
const UPPER_BOUND_CODE: u8 = 2;
/// Type code of [`PredicateType::NotEqual`].
const NOT_EQUAL_CODE: u8 = 3;

22impl Predicate {
23    /// Creates a new [`Predicate`] (also known as atomic constraint) which represents a domain
24    /// operation.
25    pub fn new(id: DomainId, predicate_type: PredicateType, value: i32) -> Self {
26        let code = match predicate_type {
27            PredicateType::LowerBound => LOWER_BOUND_CODE,
28            PredicateType::UpperBound => UPPER_BOUND_CODE,
29            PredicateType::NotEqual => NOT_EQUAL_CODE,
30            PredicateType::Equal => EQUAL_CODE,
31        };
32        let id = id.id() | (code as u32) << 30;
33        Self { id, value }
34    }
35
36    fn get_type_code(&self) -> u8 {
37        (self.id >> 30) as u8
38    }
39
40    pub fn get_predicate_type(&self) -> PredicateType {
41        (*self).into()
42    }
43}
44
45#[derive(Debug, Clone, Eq, PartialEq, Copy, Hash)]
46#[repr(u8)]
47pub enum PredicateType {
48    LowerBound = LOWER_BOUND_CODE,
49    UpperBound = UPPER_BOUND_CODE,
50    NotEqual = NOT_EQUAL_CODE,
51    Equal = EQUAL_CODE,
52}
53impl PredicateType {
54    pub(crate) fn is_lower_bound(&self) -> bool {
55        matches!(self, PredicateType::LowerBound)
56    }
57
58    pub(crate) fn is_upper_bound(&self) -> bool {
59        matches!(self, PredicateType::UpperBound)
60    }
61
62    pub(crate) fn is_disequality(&self) -> bool {
63        matches!(self, PredicateType::NotEqual)
64    }
65
66    pub(crate) fn into_predicate(
67        self,
68        domain: DomainId,
69        assignments: &Assignments,
70        removed_value: Option<i32>,
71    ) -> Predicate {
72        match self {
73            PredicateType::LowerBound => {
74                predicate!(domain >= assignments.get_lower_bound(domain))
75            }
76            PredicateType::UpperBound => predicate!(domain <= assignments.get_upper_bound(domain)),
77            PredicateType::NotEqual => predicate!(
78                domain
79                    != removed_value
80                        .expect("For a `NotEqual`, the removed value should be provided")
81            ),
82            PredicateType::Equal => predicate!(
83                domain
84                    == assignments.get_assigned_value(&domain).expect(
85                        "Expected domain to be assigned when creating an `Equal` predicate"
86                    )
87            ),
88        }
89    }
90}
91
92impl From<Predicate> for PredicateType {
93    fn from(value: Predicate) -> Self {
94        match value.get_type_code() {
95            LOWER_BOUND_CODE => Self::LowerBound,
96            UPPER_BOUND_CODE => Self::UpperBound,
97            EQUAL_CODE => Self::Equal,
98            NOT_EQUAL_CODE => Self::NotEqual,
99            code => panic!("Unknown type code {code}"),
100        }
101    }
102}
103
104impl PredicateType {
105    pub const fn into_bits(self) -> u8 {
106        self as _
107    }
108
109    pub const fn from_bits(value: u8) -> PredicateType {
110        match value {
111            LOWER_BOUND_CODE => PredicateType::LowerBound,
112            UPPER_BOUND_CODE => PredicateType::UpperBound,
113            EQUAL_CODE => PredicateType::Equal,
114            NOT_EQUAL_CODE => PredicateType::NotEqual,
115            _ => panic!("Unknown code"),
116        }
117    }
118}
119
120impl Predicate {
121    pub(crate) fn is_mutually_exclusive_with(self, other: Predicate) -> bool {
122        let domain_id = self.get_domain();
123        let rhs = self.get_right_hand_side();
124
125        let domain_id_other = other.get_domain();
126        let rhs_other = other.get_right_hand_side();
127
128        if domain_id != domain_id_other {
129            // Domain Ids do not match
130            return false;
131        }
132
133        match (self.get_predicate_type(), other.get_predicate_type()) {
134            (PredicateType::LowerBound, PredicateType::LowerBound)
135            | (PredicateType::LowerBound, PredicateType::NotEqual)
136            | (PredicateType::UpperBound, PredicateType::UpperBound)
137            | (PredicateType::UpperBound, PredicateType::NotEqual)
138            | (PredicateType::NotEqual, PredicateType::LowerBound)
139            | (PredicateType::NotEqual, PredicateType::UpperBound)
140            | (PredicateType::NotEqual, PredicateType::NotEqual) => false,
141            (PredicateType::LowerBound, PredicateType::UpperBound) => rhs > rhs_other,
142            (PredicateType::UpperBound, PredicateType::LowerBound) => rhs_other > rhs,
143            (PredicateType::LowerBound, PredicateType::Equal) => rhs > rhs_other,
144            (PredicateType::Equal, PredicateType::LowerBound) => rhs_other > rhs,
145            (PredicateType::UpperBound, PredicateType::Equal) => rhs < rhs_other,
146            (PredicateType::Equal, PredicateType::UpperBound) => rhs_other < rhs,
147            (PredicateType::NotEqual, PredicateType::Equal)
148            | (PredicateType::Equal, PredicateType::NotEqual) => rhs == rhs_other,
149            (PredicateType::Equal, PredicateType::Equal) => rhs != rhs_other,
150        }
151    }
152    pub fn is_equality_predicate(&self) -> bool {
153        self.get_type_code() == EQUAL_CODE
154    }
155
156    pub fn is_lower_bound_predicate(&self) -> bool {
157        self.get_type_code() == LOWER_BOUND_CODE
158    }
159
160    pub fn is_upper_bound_predicate(&self) -> bool {
161        self.get_type_code() == UPPER_BOUND_CODE
162    }
163
164    pub fn is_not_equal_predicate(&self) -> bool {
165        self.get_type_code() == NOT_EQUAL_CODE
166    }
167
168    /// Returns the [`DomainId`] of the [`Predicate`]
169    pub fn get_domain(&self) -> DomainId {
170        DomainId::new(0b00111111_11111111_11111111_11111111 & self.id)
171    }
172
173    pub fn get_right_hand_side(&self) -> i32 {
174        self.value
175    }
176
177    pub fn trivially_true() -> Predicate {
178        // By convention, there is a dummy 0-1 variable set to one at root.
179        // We use it to denote the trivially true predicate.
180        let domain_id = DomainId::new(0);
181        predicate!(domain_id == 1)
182    }
183
184    pub fn trivially_false() -> Predicate {
185        // By convention, there is a dummy 0-1 variable set to one at root.
186        // We use it to denote the trivially true predicate.
187        let domain_id = DomainId::new(0);
188        predicate!(domain_id != 1)
189    }
190}
191
192impl std::ops::Not for Predicate {
193    type Output = Predicate;
194
195    fn not(self) -> Self::Output {
196        let domain_id = self.get_domain();
197        let value = self.get_right_hand_side();
198
199        match self.get_predicate_type() {
200            PredicateType::LowerBound => predicate!(domain_id <= value - 1),
201            PredicateType::UpperBound => predicate!(domain_id >= value + 1),
202            PredicateType::NotEqual => predicate!(domain_id == value),
203            PredicateType::Equal => predicate!(domain_id != value),
204        }
205    }
206}
207
208impl std::fmt::Display for Predicate {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        if *self == Predicate::trivially_true() {
211            write!(f, "[True]")
212        } else if *self == Predicate::trivially_false() {
213            write!(f, "[False]")
214        } else {
215            let domain_id = self.get_domain();
216            let rhs = self.get_right_hand_side();
217
218            match self.get_predicate_type() {
219                PredicateType::LowerBound => write!(f, "[{domain_id} >= {rhs}]"),
220                PredicateType::UpperBound => write!(f, "[{domain_id} <= {rhs}]"),
221                PredicateType::NotEqual => write!(f, "[{domain_id} != {rhs}]"),
222                PredicateType::Equal => write!(f, "[{domain_id} == {rhs}]"),
223            }
224        }
225    }
226}
227
228impl std::fmt::Debug for Predicate {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        write!(f, "{self}")
231    }
232}
233
#[cfg(test)]
mod test {
    use super::Predicate;
    use crate::predicate;
    use crate::variables::DomainId;

    #[test]
    fn are_mutually_exclusive() {
        let domain_id = DomainId::new(0);

        // Pairs that must NOT be reported as mutually exclusive.
        let compatible = [
            (predicate!(domain_id >= 5), predicate!(domain_id >= 7)),
            (predicate!(domain_id >= 5), predicate!(domain_id != 2)),
            (predicate!(domain_id <= 5), predicate!(domain_id <= 8)),
            (predicate!(domain_id <= 5), predicate!(domain_id != 8)),
            (predicate!(domain_id != 9), predicate!(domain_id >= 8)),
            (predicate!(domain_id != 9), predicate!(domain_id <= 8)),
            (predicate!(domain_id != 9), predicate!(domain_id != 8)),
        ];
        for (first, second) in compatible {
            assert!(!first.is_mutually_exclusive_with(second));
        }

        // Pairs that must be reported as mutually exclusive.
        let exclusive = [
            (predicate!(domain_id <= 7), predicate!(domain_id >= 8)),
            (predicate!(domain_id >= 8), predicate!(domain_id <= 7)),
            (predicate!(domain_id >= 8), predicate!(domain_id == 7)),
            (predicate!(domain_id == 7), predicate!(domain_id >= 8)),
            (predicate!(domain_id == 7), predicate!(domain_id <= 6)),
            (predicate!(domain_id <= 6), predicate!(domain_id == 7)),
            (predicate!(domain_id != 8), predicate!(domain_id == 8)),
            (predicate!(domain_id == 8), predicate!(domain_id != 8)),
            (predicate!(domain_id == 7), predicate!(domain_id == 8)),
        ];
        for (first, second) in exclusive {
            assert!(first.is_mutually_exclusive_with(second));
        }
    }

    #[test]
    fn negating_trivially_true_predicate() {
        assert_eq!(!Predicate::trivially_true(), Predicate::trivially_false());
    }

    #[test]
    fn negating_trivially_false_predicate() {
        assert_eq!(!Predicate::trivially_false(), Predicate::trivially_true());
    }
}