// pumpkin_core/engine/predicates/predicate_constructor.rs

1use super::predicate::Predicate;
2use super::predicate::PredicateType;
3use crate::engine::variables::DomainId;
4
5/// A trait which defines methods for creating a [`Predicate`].
6pub trait PredicateConstructor {
7    /// The value used to represent a bound.
8    type Value;
9
10    /// Creates a lower-bound predicate (e.g. `[x >= v]`).
11    fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate;
12
13    /// Creates an upper-bound predicate (e.g. `[x <= v]`).
14    fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate;
15
16    /// Creates an equality predicate (e.g. `[x == v]`).
17    fn equality_predicate(&self, bound: Self::Value) -> Predicate;
18
19    /// Creates a disequality predicate (e.g. `[x != v]`).
20    fn disequality_predicate(&self, bound: Self::Value) -> Predicate;
21}
22
23impl PredicateConstructor for DomainId {
24    type Value = i32;
25
26    fn equality_predicate(&self, bound: Self::Value) -> Predicate {
27        Predicate::new(*self, PredicateType::Equal, bound)
28    }
29
30    fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate {
31        Predicate::new(*self, PredicateType::LowerBound, bound)
32    }
33
34    fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate {
35        Predicate::new(*self, PredicateType::UpperBound, bound)
36    }
37
38    fn disequality_predicate(&self, bound: Self::Value) -> Predicate {
39        Predicate::new(*self, PredicateType::NotEqual, bound)
40    }
41}
42
/// A macro which allows for the creation of a [`Predicate`].
///
/// The syntax is `predicate!(lhs OP rhs)` where:
/// - `lhs` is a path of one or more identifiers separated by `.`, and every identifier may be
///   followed by any number of index expressions (e.g. `x`, `wrapper.x`, `vars[i]`,
///   `grid[i][j]`, `model.layers[l].vars[i]`). It is expected to implement
///   `PredicateConstructor`, which the macro brings into scope itself.
/// - `OP` is one of `>=`, `<=`, `==` or `!=`, selecting the lower-bound, upper-bound,
///   equality or disequality constructor respectively.
/// - `rhs` is an arbitrary expression.
///
/// # Example
/// ```rust
/// # use pumpkin_core::Solver;
/// # use pumpkin_core::predicate;
/// # use pumpkin_core::predicates::Predicate;
/// let mut solver = Solver::default();
/// let x = solver.new_bounded_integer(0, 10);
///
/// let lower_bound_predicate = predicate!(x >= 5);
/// assert_eq!(lower_bound_predicate.get_domain(), x);
/// assert_eq!(lower_bound_predicate.get_right_hand_side(), 5);
///
/// let upper_bound_predicate = predicate!(x <= 5);
/// assert_eq!(upper_bound_predicate.get_domain(), x);
/// assert_eq!(upper_bound_predicate.get_right_hand_side(), 5);
///
/// let equality_predicate = predicate!(x == 5);
/// assert_eq!(equality_predicate.get_domain(), x);
/// assert_eq!(equality_predicate.get_right_hand_side(), 5);
///
/// let disequality_predicate = predicate!(x != 5);
/// assert_eq!(disequality_predicate.get_domain(), x);
/// assert_eq!(disequality_predicate.get_right_hand_side(), 5);
///
/// // Index expressions may be chained.
/// let grid = [[x]];
/// let indexed_predicate = predicate!(grid[0][0] >= 3);
/// assert_eq!(indexed_predicate.get_domain(), x);
/// assert_eq!(indexed_predicate.get_right_hand_side(), 3);
/// ```
#[macro_export]
macro_rules! predicate {
    // Every arm matches the same left-hand side shape: `ident [idx]* (. ident [idx]*)*`.
    // The tokens `[`, `.` and the operator are distinct, so matching is unambiguous.
    ($($var:ident $([$index:expr])*).+ >= $bound:expr) => {{
        #[allow(unused, reason = "could be imported at call-site")]
        use $crate::predicates::PredicateConstructor;
        $($var $([$index])*).+.lower_bound_predicate($bound)
    }};
    ($($var:ident $([$index:expr])*).+ <= $bound:expr) => {{
        #[allow(unused, reason = "could be imported at call-site")]
        use $crate::predicates::PredicateConstructor;
        $($var $([$index])*).+.upper_bound_predicate($bound)
    }};
    ($($var:ident $([$index:expr])*).+ == $value:expr) => {{
        #[allow(unused, reason = "could be imported at call-site")]
        use $crate::predicates::PredicateConstructor;
        $($var $([$index])*).+.equality_predicate($value)
    }};
    ($($var:ident $([$index:expr])*).+ != $value:expr) => {{
        #[allow(unused, reason = "could be imported at call-site")]
        use $crate::predicates::PredicateConstructor;
        $($var $([$index])*).+.disequality_predicate($value)
    }};
}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports which kind a predicate is, ordered as
    /// `[lower bound, upper bound, equality, not equal]`.
    fn kind_flags(built: Predicate) -> [bool; 4] {
        [
            built.is_lower_bound_predicate(),
            built.is_upper_bound_predicate(),
            built.is_equality_predicate(),
            built.is_not_equal_predicate(),
        ]
    }

    #[test]
    fn macro_local_identifiers_are_matched() {
        let x = DomainId::new(0);

        // (predicate, expected right-hand side, expected kind flags)
        let cases = [
            (predicate![x >= 2], 2, [true, false, false, false]),
            (predicate![x <= 3], 3, [false, true, false, false]),
            (predicate![x == 5], 5, [false, false, true, false]),
            (predicate![x != 5], 5, [false, false, false, true]),
        ];

        for (built, rhs, kinds) in cases {
            assert_eq!(x, built.get_domain());
            assert_eq!(rhs, built.get_right_hand_side());
            assert_eq!(kinds, kind_flags(built));
        }
    }

    #[test]
    fn macro_nested_identifiers_are_matched() {
        struct Wrapper {
            x: DomainId,
        }

        let wrapper = Wrapper {
            x: DomainId::new(0),
        };

        let cases = [
            (predicate![wrapper.x >= 2], 2),
            (predicate![wrapper.x <= 3], 3),
            (predicate![wrapper.x == 5], 5),
            (predicate![wrapper.x != 5], 5),
        ];

        for (built, rhs) in cases {
            assert_eq!(wrapper.x, built.get_domain());
            assert_eq!(rhs, built.get_right_hand_side());
        }
    }

    #[test]
    fn macro_index_expressions_are_matched() {
        let wrapper = [DomainId::new(0)];

        let cases = [
            (predicate![wrapper[0] >= 2], 2),
            (predicate![wrapper[0] <= 3], 3),
            (predicate![wrapper[0] == 5], 5),
            (predicate![wrapper[0] != 5], 5),
        ];

        for (built, rhs) in cases {
            assert_eq!(wrapper[0], built.get_domain());
            assert_eq!(rhs, built.get_right_hand_side());
        }
    }
}
167}