pumpkin_core/engine/variables/affine_view.rs

1use std::cmp::Ordering;
2
3use enumset::EnumSet;
4
5use super::TransformableVariable;
6use crate::engine::notifications::DomainEvent;
7use crate::engine::notifications::OpaqueDomainEvent;
8use crate::engine::notifications::Watchers;
9use crate::engine::predicates::predicate::Predicate;
10use crate::engine::predicates::predicate_constructor::PredicateConstructor;
11use crate::engine::variables::DomainId;
12use crate::engine::variables::IntegerVariable;
13use crate::engine::Assignments;
14use crate::math::num_ext::NumExt;
15
/// A view onto an integer variable `x` that represents `y = scale * x + offset`.
///
/// The domain of `y` is never stored explicitly: every query on the view is answered by
/// transforming the domain of the wrapped `inner` variable. [`AffineView::new`] rejects a zero
/// `scale`, which keeps the transformation invertible.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct AffineView<Inner> {
    /// The variable `x` being transformed.
    inner: Inner,
    /// The multiplicative coefficient applied to `x`.
    scale: i32,
    /// The constant added after scaling.
    offset: i32,
}
24
25impl<Inner> AffineView<Inner> {
26    pub fn new(inner: Inner, scale: i32, offset: i32) -> Self {
27        assert_ne!(scale, 0, "Multiplication by zero is not invertable");
28        AffineView {
29            inner,
30            scale,
31            offset,
32        }
33    }
34
35    /// Apply the inverse transformation of this view on a value, to go from the value in the domain
36    /// of `self` to a value in the domain of `self.inner`.
37    fn invert(&self, value: i32, rounding: Rounding) -> i32 {
38        let inverted_translation = value - self.offset;
39
40        match rounding {
41            Rounding::Up => <i32 as NumExt>::div_ceil(inverted_translation, self.scale),
42            Rounding::Down => <i32 as NumExt>::div_floor(inverted_translation, self.scale),
43        }
44    }
45
46    fn map(&self, value: i32) -> i32 {
47        self.scale * value + self.offset
48    }
49}
50
51impl<View> IntegerVariable for AffineView<View>
52where
53    View: IntegerVariable,
54{
55    type AffineView = Self;
56
57    fn lower_bound(&self, assignment: &Assignments) -> i32 {
58        if self.scale < 0 {
59            self.map(self.inner.upper_bound(assignment))
60        } else {
61            self.map(self.inner.lower_bound(assignment))
62        }
63    }
64
65    fn lower_bound_at_trail_position(
66        &self,
67        assignment: &Assignments,
68        trail_position: usize,
69    ) -> i32 {
70        if self.scale < 0 {
71            self.map(
72                self.inner
73                    .upper_bound_at_trail_position(assignment, trail_position),
74            )
75        } else {
76            self.map(
77                self.inner
78                    .lower_bound_at_trail_position(assignment, trail_position),
79            )
80        }
81    }
82
83    fn upper_bound(&self, assignment: &Assignments) -> i32 {
84        if self.scale < 0 {
85            self.map(self.inner.lower_bound(assignment))
86        } else {
87            self.map(self.inner.upper_bound(assignment))
88        }
89    }
90
91    fn upper_bound_at_trail_position(
92        &self,
93        assignment: &Assignments,
94        trail_position: usize,
95    ) -> i32 {
96        if self.scale < 0 {
97            self.map(
98                self.inner
99                    .lower_bound_at_trail_position(assignment, trail_position),
100            )
101        } else {
102            self.map(
103                self.inner
104                    .upper_bound_at_trail_position(assignment, trail_position),
105            )
106        }
107    }
108
109    fn contains(&self, assignment: &Assignments, value: i32) -> bool {
110        if (value - self.offset) % self.scale == 0 {
111            let inverted = self.invert(value, Rounding::Up);
112            self.inner.contains(assignment, inverted)
113        } else {
114            false
115        }
116    }
117
118    fn contains_at_trail_position(
119        &self,
120        assignment: &Assignments,
121        value: i32,
122        trail_position: usize,
123    ) -> bool {
124        if (value - self.offset) % self.scale == 0 {
125            let inverted = self.invert(value, Rounding::Up);
126            self.inner
127                .contains_at_trail_position(assignment, inverted, trail_position)
128        } else {
129            false
130        }
131    }
132
133    fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator<Item = i32> {
134        self.inner
135            .iterate_domain(assignment)
136            .map(|value| self.map(value))
137    }
138
139    fn watch_all(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
140        let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
141        let intersection = events.intersection(bound);
142        if intersection.len() == 1 && self.scale.is_negative() {
143            events = events.symmetrical_difference(bound);
144        }
145        self.inner.watch_all(watchers, events);
146    }
147
148    fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
149        let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
150        let intersection = events.intersection(bound);
151        if intersection.len() == 1 && self.scale.is_negative() {
152            events = events.symmetrical_difference(bound);
153        }
154        self.inner.watch_all_backtrack(watchers, events);
155    }
156
157    fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent {
158        if self.scale.is_negative() {
159            match self.inner.unpack_event(event) {
160                DomainEvent::LowerBound => DomainEvent::UpperBound,
161                DomainEvent::UpperBound => DomainEvent::LowerBound,
162                event => event,
163            }
164        } else {
165            self.inner.unpack_event(event)
166        }
167    }
168
169    fn get_holes_on_current_decision_level(
170        &self,
171        assignments: &Assignments,
172    ) -> impl Iterator<Item = i32> {
173        self.inner
174            .get_holes_on_current_decision_level(assignments)
175            .map(|value| self.map(value))
176    }
177
178    fn get_holes(&self, assignments: &Assignments) -> impl Iterator<Item = i32> {
179        self.inner
180            .get_holes(assignments)
181            .map(|value| self.map(value))
182    }
183}
184
185impl<View> TransformableVariable<AffineView<View>> for AffineView<View>
186where
187    View: IntegerVariable,
188{
189    fn scaled(&self, scale: i32) -> AffineView<View> {
190        let mut result = self.clone();
191        result.scale *= scale;
192        result.offset *= scale;
193        result
194    }
195
196    fn offset(&self, offset: i32) -> AffineView<View> {
197        let mut result = self.clone();
198        result.offset += offset;
199        result
200    }
201}
202
203impl<Var: std::fmt::Debug> std::fmt::Debug for AffineView<Var> {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        if self.scale == -1 {
206            write!(f, "-")?;
207        } else if self.scale != 1 {
208            write!(f, "{} * ", self.scale)?;
209        }
210
211        write!(f, "({:?})", self.inner)?;
212
213        match self.offset.cmp(&0) {
214            Ordering::Less => write!(f, " - {}", -self.offset)?,
215            Ordering::Equal => {}
216            Ordering::Greater => write!(f, " + {}", self.offset)?,
217        }
218
219        Ok(())
220    }
221}
222
223impl<Var: PredicateConstructor<Value = i32>> PredicateConstructor for AffineView<Var> {
224    type Value = Var::Value;
225
226    fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate {
227        if self.scale < 0 {
228            let inverted_bound = self.invert(bound, Rounding::Down);
229            self.inner.upper_bound_predicate(inverted_bound)
230        } else {
231            let inverted_bound = self.invert(bound, Rounding::Up);
232            self.inner.lower_bound_predicate(inverted_bound)
233        }
234    }
235
236    fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate {
237        if self.scale < 0 {
238            let inverted_bound = self.invert(bound, Rounding::Up);
239            self.inner.lower_bound_predicate(inverted_bound)
240        } else {
241            let inverted_bound = self.invert(bound, Rounding::Down);
242            self.inner.upper_bound_predicate(inverted_bound)
243        }
244    }
245
246    fn equality_predicate(&self, bound: Self::Value) -> Predicate {
247        if (bound - self.offset) % self.scale == 0 {
248            let inverted_bound = self.invert(bound, Rounding::Up);
249            self.inner.equality_predicate(inverted_bound)
250        } else {
251            Predicate::trivially_false()
252        }
253    }
254
255    fn disequality_predicate(&self, bound: Self::Value) -> Predicate {
256        if (bound - self.offset) % self.scale == 0 {
257            let inverted_bound = self.invert(bound, Rounding::Up);
258            self.inner.disequality_predicate(inverted_bound)
259        } else {
260            Predicate::trivially_true()
261        }
262    }
263}
264
265impl From<DomainId> for AffineView<DomainId> {
266    fn from(value: DomainId) -> Self {
267        AffineView::new(value, 1, 0)
268    }
269}
270
/// Selects how `AffineView::invert` rounds a quotient that is not an integer.
#[derive(Clone, Copy, Debug)]
enum Rounding {
    /// Take the ceiling of the exact quotient.
    Up,
    /// Take the floor of the exact quotient.
    Down,
}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::predicate;

    #[test]
    fn scaling_an_affine_view() {
        let view = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(3, view.scale);
        assert_eq!(4, view.offset);
        let scaled_view = view.scaled(6);
        assert_eq!(18, scaled_view.scale);
        assert_eq!(24, scaled_view.offset);
    }

    #[test]
    fn scaling_by_a_negative_factor_negates_scale_and_offset() {
        let view = AffineView::new(DomainId::new(0), 3, 4);
        let negated = view.scaled(-1);
        assert_eq!(-3, negated.scale);
        assert_eq!(-4, negated.offset);
    }

    #[test]
    fn offsetting_an_affine_view() {
        let view = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(3, view.scale);
        assert_eq!(4, view.offset);
        let scaled_view = view.offset(6);
        assert_eq!(3, scaled_view.scale);
        assert_eq!(10, scaled_view.offset);
    }

    #[test]
    fn affine_view_obtaining_a_bound_should_round_optimistically_in_inner_domain() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, 2, 0);

        assert_eq!(predicate!(domain >= 1), predicate!(view >= 1));
        assert_eq!(predicate!(domain >= -1), predicate!(view >= -3));
        assert_eq!(predicate!(domain <= 0), predicate!(view <= 1));
        assert_eq!(predicate!(domain <= -3), predicate!(view <= -5));
    }

    #[test]
    fn test_negated_variable_has_bounds_rounded_correctly() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, -2, 0);

        assert_eq!(predicate!(view <= -3), predicate!(domain >= 2));
        assert_eq!(predicate!(view >= 5), predicate!(domain <= -3));
    }

    #[test]
    fn negative_scale_with_offset_rounds_bounds_correctly() {
        let domain = DomainId::new(0);
        // view = -2 * domain + 3
        let view = AffineView::new(domain, -2, 3);

        // -2x + 3 >= 5  <=>  x <= -1
        assert_eq!(predicate!(domain <= -1), predicate!(view >= 5));
        // -2x + 3 <= 4  <=>  x >= -1/2  <=>  x >= 0
        assert_eq!(predicate!(domain >= 0), predicate!(view <= 4));
    }

    #[test]
    fn equality_predicates_respect_the_image_of_the_view() {
        let domain = DomainId::new(0);
        // view = 2 * domain + 1 only takes odd values.
        let view = AffineView::new(domain, 2, 1);

        assert_eq!(predicate!(domain == 2), predicate!(view == 5));
        assert_eq!(predicate!(domain != 2), predicate!(view != 5));
        // 4 is even, so it is not in the image of the view.
        assert_eq!(Predicate::trivially_false(), predicate!(view == 4));
        assert_eq!(Predicate::trivially_true(), predicate!(view != 4));
    }
}
320}