// pumpkin_core/engine/variables/affine_view.rs

1use std::cmp::Ordering;
2
3use enumset::EnumSet;
4use pumpkin_checking::CheckerVariable;
5use pumpkin_checking::IntExt;
6
7use super::TransformableVariable;
8use crate::engine::Assignments;
9use crate::engine::notifications::DomainEvent;
10use crate::engine::notifications::OpaqueDomainEvent;
11use crate::engine::notifications::Watchers;
12use crate::engine::predicates::predicate::Predicate;
13use crate::engine::predicates::predicate_constructor::PredicateConstructor;
14use crate::engine::variables::DomainId;
15use crate::engine::variables::IntegerVariable;
16use crate::math::num_ext::NumExt;
17
/// A view `y = scale * x + offset` over an inner variable `x`.
///
/// The view holds only the inner variable and the two coefficients; the domain of `y` is
/// expressed as a transformation of the domain of `x`.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct AffineView<Inner> {
    /// The variable `x` that is being transformed.
    pub(crate) inner: Inner,
    /// The multiplier `a` in `y = ax + b`.
    pub(crate) scale: i32,
    /// The constant `b` in `y = ax + b`.
    pub(crate) offset: i32,
}
26
27impl<Inner> AffineView<Inner> {
28    pub fn new(inner: Inner, scale: i32, offset: i32) -> Self {
29        assert_ne!(scale, 0, "Multiplication by zero is not invertable");
30        AffineView {
31            inner,
32            scale,
33            offset,
34        }
35    }
36
37    pub fn inner(&self) -> &Inner {
38        &self.inner
39    }
40
41    /// Apply the inverse transformation of this view on a value, to go from the value in the domain
42    /// of `self` to a value in the domain of `self.inner`.
43    fn invert(&self, value: i32, rounding: Rounding) -> i32 {
44        let inverted_translation = value - self.offset;
45
46        match rounding {
47            Rounding::Up => <i32 as NumExt>::div_ceil(inverted_translation, self.scale),
48            Rounding::Down => <i32 as NumExt>::div_floor(inverted_translation, self.scale),
49        }
50    }
51
52    fn map(&self, value: i32) -> i32 {
53        self.scale * value + self.offset
54    }
55}
56
57impl<Var: IntegerVariable> CheckerVariable<Predicate> for AffineView<Var> {
58    fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool {
59        self.inner.does_atomic_constrain_self(atomic)
60    }
61
62    fn atomic_less_than(&self, value: i32) -> Predicate {
63        use crate::predicate;
64
65        predicate![self <= value]
66    }
67
68    fn atomic_greater_than(&self, value: i32) -> Predicate {
69        use crate::predicate;
70
71        predicate![self >= value]
72    }
73
74    fn atomic_equal(&self, value: i32) -> Predicate {
75        use crate::predicate;
76
77        predicate![self == value]
78    }
79
80    fn atomic_not_equal(&self, value: i32) -> Predicate {
81        use crate::predicate;
82
83        predicate![self != value]
84    }
85
86    fn induced_lower_bound(
87        &self,
88        variable_state: &pumpkin_checking::VariableState<Predicate>,
89    ) -> IntExt {
90        if self.scale.is_positive() {
91            match self.inner.induced_lower_bound(variable_state) {
92                IntExt::Int(value) => IntExt::Int(self.map(value)),
93                bound => bound,
94            }
95        } else {
96            match self.inner.induced_upper_bound(variable_state) {
97                IntExt::Int(value) => IntExt::Int(self.map(value)),
98                IntExt::NegativeInf => IntExt::PositiveInf,
99                IntExt::PositiveInf => IntExt::NegativeInf,
100            }
101        }
102    }
103
104    fn induced_upper_bound(
105        &self,
106        variable_state: &pumpkin_checking::VariableState<Predicate>,
107    ) -> IntExt {
108        if self.scale.is_positive() {
109            match self.inner.induced_upper_bound(variable_state) {
110                IntExt::Int(value) => IntExt::Int(self.map(value)),
111                bound => bound,
112            }
113        } else {
114            match self.inner.induced_lower_bound(variable_state) {
115                IntExt::Int(value) => IntExt::Int(self.map(value)),
116                IntExt::NegativeInf => IntExt::PositiveInf,
117                IntExt::PositiveInf => IntExt::NegativeInf,
118            }
119        }
120    }
121
122    fn induced_fixed_value(
123        &self,
124        variable_state: &pumpkin_checking::VariableState<Predicate>,
125    ) -> Option<i32> {
126        self.inner
127            .induced_fixed_value(variable_state)
128            .map(|value| self.map(value))
129    }
130
131    fn induced_domain_contains(
132        &self,
133        variable_state: &pumpkin_checking::VariableState<Predicate>,
134        value: i32,
135    ) -> bool {
136        let translated_value = value - self.offset;
137
138        // If the translated value does not divide by scale, then the original value is not in the
139        // domain of this affine view.
140        if translated_value % self.scale != 0 {
141            return false;
142        }
143
144        let unscaled_value = translated_value / self.scale;
145
146        self.inner
147            .induced_domain_contains(variable_state, unscaled_value)
148    }
149
150    fn induced_holes<'this, 'state>(
151        &'this self,
152        variable_state: &'state pumpkin_checking::VariableState<Predicate>,
153    ) -> impl Iterator<Item = i32> + 'state
154    where
155        'this: 'state,
156    {
157        if self.scale == 1 || self.scale == -1 {
158            return self
159                .inner
160                .induced_holes(variable_state)
161                .map(|value| self.map(value));
162        }
163
164        todo!("how to iterate holes of a scaled domain");
165    }
166
167    fn iter_induced_domain<'this, 'state>(
168        &'this self,
169        variable_state: &'state pumpkin_checking::VariableState<Predicate>,
170    ) -> Option<impl Iterator<Item = i32> + 'state>
171    where
172        'this: 'state,
173    {
174        self.inner
175            .iter_induced_domain(variable_state)
176            .map(|iter| iter.map(|value| self.map(value)))
177    }
178}
179
180impl<View> IntegerVariable for AffineView<View>
181where
182    View: IntegerVariable,
183{
184    type AffineView = Self;
185
186    fn lower_bound(&self, assignment: &Assignments) -> i32 {
187        if self.scale < 0 {
188            self.map(self.inner.upper_bound(assignment))
189        } else {
190            self.map(self.inner.lower_bound(assignment))
191        }
192    }
193
194    fn lower_bound_at_trail_position(
195        &self,
196        assignment: &Assignments,
197        trail_position: usize,
198    ) -> i32 {
199        if self.scale < 0 {
200            self.map(
201                self.inner
202                    .upper_bound_at_trail_position(assignment, trail_position),
203            )
204        } else {
205            self.map(
206                self.inner
207                    .lower_bound_at_trail_position(assignment, trail_position),
208            )
209        }
210    }
211
212    fn upper_bound(&self, assignment: &Assignments) -> i32 {
213        if self.scale < 0 {
214            self.map(self.inner.lower_bound(assignment))
215        } else {
216            self.map(self.inner.upper_bound(assignment))
217        }
218    }
219
220    fn upper_bound_at_trail_position(
221        &self,
222        assignment: &Assignments,
223        trail_position: usize,
224    ) -> i32 {
225        if self.scale < 0 {
226            self.map(
227                self.inner
228                    .lower_bound_at_trail_position(assignment, trail_position),
229            )
230        } else {
231            self.map(
232                self.inner
233                    .upper_bound_at_trail_position(assignment, trail_position),
234            )
235        }
236    }
237
238    fn contains(&self, assignment: &Assignments, value: i32) -> bool {
239        if (value - self.offset) % self.scale == 0 {
240            let inverted = self.invert(value, Rounding::Up);
241            self.inner.contains(assignment, inverted)
242        } else {
243            false
244        }
245    }
246
247    fn contains_at_trail_position(
248        &self,
249        assignment: &Assignments,
250        value: i32,
251        trail_position: usize,
252    ) -> bool {
253        if (value - self.offset) % self.scale == 0 {
254            let inverted = self.invert(value, Rounding::Up);
255            self.inner
256                .contains_at_trail_position(assignment, inverted, trail_position)
257        } else {
258            false
259        }
260    }
261
262    fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator<Item = i32> {
263        self.inner
264            .iterate_domain(assignment)
265            .map(|value| self.map(value))
266    }
267
268    fn watch_all(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
269        let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
270        let intersection = events.intersection(bound);
271        if intersection.len() == 1 && self.scale.is_negative() {
272            events = events.symmetrical_difference(bound);
273        }
274        self.inner.watch_all(watchers, events);
275    }
276
277    fn unwatch_all(&self, watchers: &mut Watchers<'_>) {
278        self.inner.unwatch_all(watchers);
279    }
280
281    fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
282        let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
283        let intersection = events.intersection(bound);
284        if intersection.len() == 1 && self.scale.is_negative() {
285            events = events.symmetrical_difference(bound);
286        }
287        self.inner.watch_all_backtrack(watchers, events);
288    }
289
290    fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent {
291        if self.scale.is_negative() {
292            match self.inner.unpack_event(event) {
293                DomainEvent::LowerBound => DomainEvent::UpperBound,
294                DomainEvent::UpperBound => DomainEvent::LowerBound,
295                event => event,
296            }
297        } else {
298            self.inner.unpack_event(event)
299        }
300    }
301
302    fn get_holes_at_current_checkpoint(
303        &self,
304        assignments: &Assignments,
305    ) -> impl Iterator<Item = i32> {
306        self.inner
307            .get_holes_at_current_checkpoint(assignments)
308            .map(|value| self.map(value))
309    }
310
311    fn get_holes(&self, assignments: &Assignments) -> impl Iterator<Item = i32> {
312        self.inner
313            .get_holes(assignments)
314            .map(|value| self.map(value))
315    }
316}
317
318impl<View> TransformableVariable<AffineView<View>> for AffineView<View>
319where
320    View: IntegerVariable,
321{
322    fn scaled(&self, scale: i32) -> AffineView<View> {
323        let mut result = self.clone();
324        result.scale *= scale;
325        result.offset *= scale;
326        result
327    }
328
329    fn offset(&self, offset: i32) -> AffineView<View> {
330        let mut result = self.clone();
331        result.offset += offset;
332        result
333    }
334}
335
336impl<Var: std::fmt::Debug> std::fmt::Debug for AffineView<Var> {
337    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338        if self.scale == -1 {
339            write!(f, "-")?;
340        } else if self.scale != 1 {
341            write!(f, "{} * ", self.scale)?;
342        }
343
344        write!(f, "({:?})", self.inner)?;
345
346        match self.offset.cmp(&0) {
347            Ordering::Less => write!(f, " - {}", -self.offset)?,
348            Ordering::Equal => {}
349            Ordering::Greater => write!(f, " + {}", self.offset)?,
350        }
351
352        Ok(())
353    }
354}
355
356impl<Var: PredicateConstructor<Value = i32>> PredicateConstructor for AffineView<Var> {
357    type Value = Var::Value;
358
359    fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate {
360        if self.scale < 0 {
361            let inverted_bound = self.invert(bound, Rounding::Down);
362            self.inner.upper_bound_predicate(inverted_bound)
363        } else {
364            let inverted_bound = self.invert(bound, Rounding::Up);
365            self.inner.lower_bound_predicate(inverted_bound)
366        }
367    }
368
369    fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate {
370        if self.scale < 0 {
371            let inverted_bound = self.invert(bound, Rounding::Up);
372            self.inner.lower_bound_predicate(inverted_bound)
373        } else {
374            let inverted_bound = self.invert(bound, Rounding::Down);
375            self.inner.upper_bound_predicate(inverted_bound)
376        }
377    }
378
379    fn equality_predicate(&self, bound: Self::Value) -> Predicate {
380        if (bound - self.offset) % self.scale == 0 {
381            let inverted_bound = self.invert(bound, Rounding::Up);
382            self.inner.equality_predicate(inverted_bound)
383        } else {
384            Predicate::trivially_false()
385        }
386    }
387
388    fn disequality_predicate(&self, bound: Self::Value) -> Predicate {
389        if (bound - self.offset) % self.scale == 0 {
390            let inverted_bound = self.invert(bound, Rounding::Up);
391            self.inner.disequality_predicate(inverted_bound)
392        } else {
393            Predicate::trivially_true()
394        }
395    }
396}
397
398impl From<DomainId> for AffineView<DomainId> {
399    fn from(value: DomainId) -> Self {
400        AffineView::new(value, 1, 0)
401    }
402}
403
/// Direction in which a non-exact quotient is rounded when a value is translated back to the
/// inner domain.
enum Rounding {
    /// Round up (ceiling division).
    Up,
    /// Round down (floor division).
    Down,
}
408
#[cfg(test)]
mod tests {
    use super::*;
    use crate::predicate;

    /// Scaling a view multiplies both its scale and its offset.
    #[test]
    fn scaling_an_affine_view() {
        let base = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(3, base.scale);
        assert_eq!(4, base.offset);

        let scaled = base.scaled(6);
        assert_eq!(18, scaled.scale);
        assert_eq!(24, scaled.offset);
    }

    /// Offsetting a view adds to its offset and leaves the scale alone.
    #[test]
    fn offsetting_an_affine_view() {
        let base = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(3, base.scale);
        assert_eq!(4, base.offset);

        let shifted = base.offset(6);
        assert_eq!(3, shifted.scale);
        assert_eq!(10, shifted.offset);
    }

    /// Bounds that fall between two multiples of the scale are rounded towards the feasible
    /// side in the inner domain.
    #[test]
    fn affine_view_obtaining_a_bound_should_round_optimistically_in_inner_domain() {
        let x = DomainId::new(0);
        let doubled = AffineView::new(x, 2, 0);

        assert_eq!(predicate!(x >= 1), predicate!(doubled >= 1));
        assert_eq!(predicate!(x >= -1), predicate!(doubled >= -3));
        assert_eq!(predicate!(x <= 0), predicate!(doubled <= 1));
        assert_eq!(predicate!(x <= -3), predicate!(doubled <= -5));
    }

    /// With a negative scale the bound direction flips, and the rounding flips with it.
    #[test]
    fn test_negated_variable_has_bounds_rounded_correctly() {
        let x = DomainId::new(0);
        let negated_double = AffineView::new(x, -2, 0);

        assert_eq!(predicate!(negated_double <= -3), predicate!(x >= 2));
        assert_eq!(predicate!(negated_double >= 5), predicate!(x <= -3));
    }
}
453}