pumpkin_core/engine/predicates/
predicate.rs

1use crate::engine::variables::DomainId;
2use crate::engine::Assignments;
3use crate::predicate;
4
5/// Representation of a domain operation, also known as an atomic constraint. It is a triple
6/// ([`DomainId`], [`PredicateType`], value).
7///
8/// To create a [`Predicate`], use [Predicate::new] or the more concise [predicate!] macro.
#[derive(Clone, PartialEq, Eq, Copy, Hash)]
pub struct Predicate {
    /// Packed encoding of the domain and the predicate type: the two most significant bits
    /// contain the type code of the predicate, and the remaining 30 bits contain the id of the
    /// [`DomainId`].
    id: u32,
    /// The right-hand side (value) of the predicate.
    value: i32,
}
16
/// Two-bit type code for an `==` predicate, stored in the top bits of `Predicate::id`.
const EQUAL_CODE: u8 = 0;
/// Two-bit type code for a `>=` predicate, stored in the top bits of `Predicate::id`.
const LOWER_BOUND_CODE: u8 = 1;
/// Two-bit type code for a `<=` predicate, stored in the top bits of `Predicate::id`.
const UPPER_BOUND_CODE: u8 = 2;
/// Two-bit type code for a `!=` predicate, stored in the top bits of `Predicate::id`.
const NOT_EQUAL_CODE: u8 = 3;
21
22impl Predicate {
23    /// Creates a new [`Predicate`] (also known as atomic constraint) which represents a domain
24    /// operation.
25    pub fn new(id: DomainId, predicate_type: PredicateType, value: i32) -> Self {
26        let code = match predicate_type {
27            PredicateType::LowerBound => LOWER_BOUND_CODE,
28            PredicateType::UpperBound => UPPER_BOUND_CODE,
29            PredicateType::NotEqual => NOT_EQUAL_CODE,
30            PredicateType::Equal => EQUAL_CODE,
31        };
32        let id = id.id() | (code as u32) << 30;
33        Self { id, value }
34    }
35
36    fn get_type_code(&self) -> u8 {
37        (self.id >> 30) as u8
38    }
39
40    pub fn get_predicate_type(&self) -> PredicateType {
41        (*self).into()
42    }
43}
44
/// The comparison performed by a [`Predicate`] between its domain and its right-hand side.
#[derive(Debug, Clone, Eq, PartialEq, Copy, Hash)]
pub enum PredicateType {
    /// `[domain >= value]`
    LowerBound,
    /// `[domain <= value]`
    UpperBound,
    /// `[domain != value]`
    NotEqual,
    /// `[domain == value]`
    Equal,
}
52impl PredicateType {
53    pub(crate) fn is_lower_bound(&self) -> bool {
54        matches!(self, PredicateType::LowerBound)
55    }
56
57    pub(crate) fn is_upper_bound(&self) -> bool {
58        matches!(self, PredicateType::UpperBound)
59    }
60
61    pub(crate) fn into_predicate(
62        self,
63        domain: DomainId,
64        assignments: &Assignments,
65        removed_value: Option<i32>,
66    ) -> Predicate {
67        match self {
68            PredicateType::LowerBound => {
69                predicate!(domain >= assignments.get_lower_bound(domain))
70            }
71            PredicateType::UpperBound => predicate!(domain <= assignments.get_upper_bound(domain)),
72            PredicateType::NotEqual => predicate!(
73                domain
74                    != removed_value
75                        .expect("For a `NotEqual`, the removed value should be provided")
76            ),
77            PredicateType::Equal => predicate!(
78                domain
79                    == assignments.get_assigned_value(&domain).expect(
80                        "Expected domain to be assigned when creating an `Equal` predicate"
81                    )
82            ),
83        }
84    }
85}
86
87impl From<Predicate> for PredicateType {
88    fn from(value: Predicate) -> Self {
89        match value.get_type_code() {
90            LOWER_BOUND_CODE => Self::LowerBound,
91            UPPER_BOUND_CODE => Self::UpperBound,
92            EQUAL_CODE => Self::Equal,
93            NOT_EQUAL_CODE => Self::NotEqual,
94            code => panic!("Unknown type code {code}"),
95        }
96    }
97}
98
99impl Predicate {
100    pub(crate) fn is_mutually_exclusive_with(self, other: Predicate) -> bool {
101        let domain_id = self.get_domain();
102        let rhs = self.get_right_hand_side();
103
104        let domain_id_other = other.get_domain();
105        let rhs_other = other.get_right_hand_side();
106
107        if domain_id != domain_id_other {
108            // Domain Ids do not match
109            return false;
110        }
111
112        match (self.get_predicate_type(), other.get_predicate_type()) {
113            (PredicateType::LowerBound, PredicateType::LowerBound)
114            | (PredicateType::LowerBound, PredicateType::NotEqual)
115            | (PredicateType::UpperBound, PredicateType::UpperBound)
116            | (PredicateType::UpperBound, PredicateType::NotEqual)
117            | (PredicateType::NotEqual, PredicateType::LowerBound)
118            | (PredicateType::NotEqual, PredicateType::UpperBound)
119            | (PredicateType::NotEqual, PredicateType::NotEqual) => false,
120            (PredicateType::LowerBound, PredicateType::UpperBound) => rhs > rhs_other,
121            (PredicateType::UpperBound, PredicateType::LowerBound) => rhs_other > rhs,
122            (PredicateType::LowerBound, PredicateType::Equal) => rhs > rhs_other,
123            (PredicateType::Equal, PredicateType::LowerBound) => rhs_other > rhs,
124            (PredicateType::UpperBound, PredicateType::Equal) => rhs < rhs_other,
125            (PredicateType::Equal, PredicateType::UpperBound) => rhs_other < rhs,
126            (PredicateType::NotEqual, PredicateType::Equal)
127            | (PredicateType::Equal, PredicateType::NotEqual) => rhs == rhs_other,
128            (PredicateType::Equal, PredicateType::Equal) => rhs != rhs_other,
129        }
130    }
131    pub fn is_equality_predicate(&self) -> bool {
132        self.get_type_code() == EQUAL_CODE
133    }
134
135    pub fn is_lower_bound_predicate(&self) -> bool {
136        self.get_type_code() == LOWER_BOUND_CODE
137    }
138
139    pub fn is_upper_bound_predicate(&self) -> bool {
140        self.get_type_code() == UPPER_BOUND_CODE
141    }
142
143    pub fn is_not_equal_predicate(&self) -> bool {
144        self.get_type_code() == NOT_EQUAL_CODE
145    }
146
147    /// Returns the [`DomainId`] of the [`Predicate`]
148    pub fn get_domain(&self) -> DomainId {
149        DomainId::new(0b00111111_11111111_11111111_11111111 & self.id)
150    }
151
152    pub fn get_right_hand_side(&self) -> i32 {
153        self.value
154    }
155
156    pub fn trivially_true() -> Predicate {
157        // By convention, there is a dummy 0-1 variable set to one at root.
158        // We use it to denote the trivially true predicate.
159        let domain_id = DomainId::new(0);
160        predicate!(domain_id == 1)
161    }
162
163    pub fn trivially_false() -> Predicate {
164        // By convention, there is a dummy 0-1 variable set to one at root.
165        // We use it to denote the trivially true predicate.
166        let domain_id = DomainId::new(0);
167        predicate!(domain_id != 1)
168    }
169}
170
171impl std::ops::Not for Predicate {
172    type Output = Predicate;
173
174    fn not(self) -> Self::Output {
175        let domain_id = self.get_domain();
176        let value = self.get_right_hand_side();
177
178        match self.get_predicate_type() {
179            PredicateType::LowerBound => predicate!(domain_id <= value - 1),
180            PredicateType::UpperBound => predicate!(domain_id >= value + 1),
181            PredicateType::NotEqual => predicate!(domain_id == value),
182            PredicateType::Equal => predicate!(domain_id != value),
183        }
184    }
185}
186
187impl std::fmt::Display for Predicate {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        if *self == Predicate::trivially_true() {
190            write!(f, "[True]")
191        } else if *self == Predicate::trivially_false() {
192            write!(f, "[False]")
193        } else {
194            let domain_id = self.get_domain();
195            let rhs = self.get_right_hand_side();
196
197            match self.get_predicate_type() {
198                PredicateType::LowerBound => write!(f, "[{domain_id} >= {rhs}]"),
199                PredicateType::UpperBound => write!(f, "[{domain_id} <= {rhs}]"),
200                PredicateType::NotEqual => write!(f, "[{domain_id} != {rhs}]"),
201                PredicateType::Equal => write!(f, "[{domain_id} == {rhs}]"),
202            }
203        }
204    }
205}
206
207impl std::fmt::Debug for Predicate {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        write!(f, "{self}")
210    }
211}
212
#[cfg(test)]
mod test {
    use super::Predicate;
    use crate::predicate;
    use crate::variables::DomainId;

    #[test]
    fn are_mutually_exclusive() {
        let x = DomainId::new(0);

        // Pairs that can hold simultaneously.
        let compatible = [
            (predicate!(x >= 5), predicate!(x >= 7)),
            (predicate!(x >= 5), predicate!(x != 2)),
            (predicate!(x <= 5), predicate!(x <= 8)),
            (predicate!(x <= 5), predicate!(x != 8)),
            (predicate!(x != 9), predicate!(x >= 8)),
            (predicate!(x != 9), predicate!(x <= 8)),
            (predicate!(x != 9), predicate!(x != 8)),
        ];
        for &(lhs, rhs) in compatible.iter() {
            assert!(!lhs.is_mutually_exclusive_with(rhs));
        }

        // Pairs that can never hold together.
        let conflicting = [
            (predicate!(x <= 7), predicate!(x >= 8)),
            (predicate!(x >= 8), predicate!(x <= 7)),
            (predicate!(x >= 8), predicate!(x == 7)),
            (predicate!(x == 7), predicate!(x >= 8)),
            (predicate!(x == 7), predicate!(x <= 6)),
            (predicate!(x <= 6), predicate!(x == 7)),
            (predicate!(x != 8), predicate!(x == 8)),
            (predicate!(x == 8), predicate!(x != 8)),
            (predicate!(x == 7), predicate!(x == 8)),
        ];
        for &(lhs, rhs) in conflicting.iter() {
            assert!(lhs.is_mutually_exclusive_with(rhs));
        }
    }

    #[test]
    fn negating_trivially_true_predicate() {
        let negated = !Predicate::trivially_true();
        assert!(negated == Predicate::trivially_false());
    }

    #[test]
    fn negating_trivially_false_predicate() {
        let negated = !Predicate::trivially_false();
        assert!(negated == Predicate::trivially_true());
    }
}