pumpkin_core/engine/variables/
affine_view.rs

1use std::cmp::Ordering;
2
3use enumset::EnumSet;
4
5use super::TransformableVariable;
6use crate::engine::notifications::DomainEvent;
7use crate::engine::notifications::OpaqueDomainEvent;
8use crate::engine::notifications::Watchers;
9use crate::engine::predicates::predicate::Predicate;
10use crate::engine::predicates::predicate_constructor::PredicateConstructor;
11use crate::engine::variables::DomainId;
12use crate::engine::variables::IntegerVariable;
13use crate::engine::Assignments;
14use crate::math::num_ext::NumExt;
15
/// A view `y = scale * x + offset` over an inner variable `x`.
///
/// The domain of `y` is never stored explicitly: queries about `y` are answered by transforming
/// the domain of `x`, and predicates over `y` are translated back into predicates over `x`.
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub struct AffineView<Inner> {
    /// The viewed variable `x`.
    inner: Inner,
    /// The multiplicative factor applied to `x`.
    scale: i32,
    /// The constant added after scaling.
    offset: i32,
}
24
25impl<Inner> AffineView<Inner> {
26    pub fn new(inner: Inner, scale: i32, offset: i32) -> Self {
27        AffineView {
28            inner,
29            scale,
30            offset,
31        }
32    }
33
34    /// Apply the inverse transformation of this view on a value, to go from the value in the domain
35    /// of `self` to a value in the domain of `self.inner`.
36    fn invert(&self, value: i32, rounding: Rounding) -> i32 {
37        let inverted_translation = value - self.offset;
38
39        match rounding {
40            Rounding::Up => <i32 as NumExt>::div_ceil(inverted_translation, self.scale),
41            Rounding::Down => <i32 as NumExt>::div_floor(inverted_translation, self.scale),
42        }
43    }
44
45    fn map(&self, value: i32) -> i32 {
46        self.scale * value + self.offset
47    }
48}
49
50impl<View> IntegerVariable for AffineView<View>
51where
52    View: IntegerVariable,
53{
54    type AffineView = Self;
55
56    fn lower_bound(&self, assignment: &Assignments) -> i32 {
57        if self.scale < 0 {
58            self.map(self.inner.upper_bound(assignment))
59        } else {
60            self.map(self.inner.lower_bound(assignment))
61        }
62    }
63
64    fn lower_bound_at_trail_position(
65        &self,
66        assignment: &Assignments,
67        trail_position: usize,
68    ) -> i32 {
69        if self.scale < 0 {
70            self.map(
71                self.inner
72                    .upper_bound_at_trail_position(assignment, trail_position),
73            )
74        } else {
75            self.map(
76                self.inner
77                    .lower_bound_at_trail_position(assignment, trail_position),
78            )
79        }
80    }
81
82    fn upper_bound(&self, assignment: &Assignments) -> i32 {
83        if self.scale < 0 {
84            self.map(self.inner.lower_bound(assignment))
85        } else {
86            self.map(self.inner.upper_bound(assignment))
87        }
88    }
89
90    fn upper_bound_at_trail_position(
91        &self,
92        assignment: &Assignments,
93        trail_position: usize,
94    ) -> i32 {
95        if self.scale < 0 {
96            self.map(
97                self.inner
98                    .lower_bound_at_trail_position(assignment, trail_position),
99            )
100        } else {
101            self.map(
102                self.inner
103                    .upper_bound_at_trail_position(assignment, trail_position),
104            )
105        }
106    }
107
108    fn contains(&self, assignment: &Assignments, value: i32) -> bool {
109        if (value - self.offset) % self.scale == 0 {
110            let inverted = self.invert(value, Rounding::Up);
111            self.inner.contains(assignment, inverted)
112        } else {
113            false
114        }
115    }
116
117    fn contains_at_trail_position(
118        &self,
119        assignment: &Assignments,
120        value: i32,
121        trail_position: usize,
122    ) -> bool {
123        if (value - self.offset) % self.scale == 0 {
124            let inverted = self.invert(value, Rounding::Up);
125            self.inner
126                .contains_at_trail_position(assignment, inverted, trail_position)
127        } else {
128            false
129        }
130    }
131
132    fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator<Item = i32> {
133        self.inner
134            .iterate_domain(assignment)
135            .map(|value| self.map(value))
136    }
137
138    fn watch_all(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
139        let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
140        let intersection = events.intersection(bound);
141        if intersection.len() == 1 && self.scale.is_negative() {
142            events = events.symmetrical_difference(bound);
143        }
144        self.inner.watch_all(watchers, events);
145    }
146
147    fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
148        let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
149        let intersection = events.intersection(bound);
150        if intersection.len() == 1 && self.scale.is_negative() {
151            events = events.symmetrical_difference(bound);
152        }
153        self.inner.watch_all_backtrack(watchers, events);
154    }
155
156    fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent {
157        if self.scale.is_negative() {
158            match self.inner.unpack_event(event) {
159                DomainEvent::LowerBound => DomainEvent::UpperBound,
160                DomainEvent::UpperBound => DomainEvent::LowerBound,
161                event => event,
162            }
163        } else {
164            self.inner.unpack_event(event)
165        }
166    }
167
168    fn get_holes_on_current_decision_level(
169        &self,
170        assignments: &Assignments,
171    ) -> impl Iterator<Item = i32> {
172        self.inner
173            .get_holes_on_current_decision_level(assignments)
174            .map(|value| self.map(value))
175    }
176
177    fn get_holes(&self, assignments: &Assignments) -> impl Iterator<Item = i32> {
178        self.inner
179            .get_holes(assignments)
180            .map(|value| self.map(value))
181    }
182}
183
184impl<View> TransformableVariable<AffineView<View>> for AffineView<View>
185where
186    View: IntegerVariable,
187{
188    fn scaled(&self, scale: i32) -> AffineView<View> {
189        let mut result = self.clone();
190        result.scale *= scale;
191        result.offset *= scale;
192        result
193    }
194
195    fn offset(&self, offset: i32) -> AffineView<View> {
196        let mut result = self.clone();
197        result.offset += offset;
198        result
199    }
200}
201
202impl<Var: std::fmt::Debug> std::fmt::Debug for AffineView<Var> {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        if self.scale == -1 {
205            write!(f, "-")?;
206        } else if self.scale != 1 {
207            write!(f, "{} * ", self.scale)?;
208        }
209
210        write!(f, "({:?})", self.inner)?;
211
212        match self.offset.cmp(&0) {
213            Ordering::Less => write!(f, " - {}", -self.offset)?,
214            Ordering::Equal => {}
215            Ordering::Greater => write!(f, " + {}", self.offset)?,
216        }
217
218        Ok(())
219    }
220}
221
222impl<Var: PredicateConstructor<Value = i32>> PredicateConstructor for AffineView<Var> {
223    type Value = Var::Value;
224
225    fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate {
226        if self.scale < 0 {
227            let inverted_bound = self.invert(bound, Rounding::Down);
228            self.inner.upper_bound_predicate(inverted_bound)
229        } else {
230            let inverted_bound = self.invert(bound, Rounding::Up);
231            self.inner.lower_bound_predicate(inverted_bound)
232        }
233    }
234
235    fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate {
236        if self.scale < 0 {
237            let inverted_bound = self.invert(bound, Rounding::Up);
238            self.inner.lower_bound_predicate(inverted_bound)
239        } else {
240            let inverted_bound = self.invert(bound, Rounding::Down);
241            self.inner.upper_bound_predicate(inverted_bound)
242        }
243    }
244
245    fn equality_predicate(&self, bound: Self::Value) -> Predicate {
246        if (bound - self.offset) % self.scale == 0 {
247            let inverted_bound = self.invert(bound, Rounding::Up);
248            self.inner.equality_predicate(inverted_bound)
249        } else {
250            Predicate::trivially_false()
251        }
252    }
253
254    fn disequality_predicate(&self, bound: Self::Value) -> Predicate {
255        if (bound - self.offset) % self.scale == 0 {
256            let inverted_bound = self.invert(bound, Rounding::Up);
257            self.inner.disequality_predicate(inverted_bound)
258        } else {
259            Predicate::trivially_true()
260        }
261    }
262}
263
264impl From<DomainId> for AffineView<DomainId> {
265    fn from(value: DomainId) -> Self {
266        AffineView::new(value, 1, 0)
267    }
268}
269
/// Direction in which `AffineView::invert` rounds when the exact pre-image of a value is not
/// an integer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Rounding {
    /// Round with `div_ceil`, towards positive infinity.
    Up,
    /// Round with `div_floor`, towards negative infinity.
    Down,
}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::predicate;

    /// The `(scale, offset)` pair of `view`.
    fn coefficients<Inner>(view: &AffineView<Inner>) -> (i32, i32) {
        (view.scale, view.offset)
    }

    #[test]
    fn scaling_an_affine_view() {
        let base = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(coefficients(&base), (3, 4));

        // 6 * (3x + 4) = 18x + 24
        assert_eq!(coefficients(&base.scaled(6)), (18, 24));
    }

    #[test]
    fn offsetting_an_affine_view() {
        let base = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(coefficients(&base), (3, 4));

        // (3x + 4) + 6 = 3x + 10
        assert_eq!(coefficients(&base.offset(6)), (3, 10));
    }

    #[test]
    fn affine_view_obtaining_a_bound_should_round_optimistically_in_inner_domain() {
        let x = DomainId::new(0);
        let doubled = AffineView::new(x, 2, 0);

        // 2x >= 1   <=>  x >= 0.5   <=>  x >= 1
        assert_eq!(predicate!(doubled >= 1), predicate!(x >= 1));
        // 2x >= -3  <=>  x >= -1.5  <=>  x >= -1
        assert_eq!(predicate!(doubled >= -3), predicate!(x >= -1));
        // 2x <= 1   <=>  x <= 0.5   <=>  x <= 0
        assert_eq!(predicate!(doubled <= 1), predicate!(x <= 0));
        // 2x <= -5  <=>  x <= -2.5  <=>  x <= -3
        assert_eq!(predicate!(doubled <= -5), predicate!(x <= -3));
    }

    #[test]
    fn test_negated_variable_has_bounds_rounded_correctly() {
        let x = DomainId::new(0);
        let negated = AffineView::new(x, -2, 0);

        // -2x <= -3  <=>  x >= 1.5   <=>  x >= 2
        assert_eq!(predicate!(x >= 2), predicate!(negated <= -3));
        // -2x >= 5   <=>  x <= -2.5  <=>  x <= -3
        assert_eq!(predicate!(x <= -3), predicate!(negated >= 5));
    }
}