pumpkin_core/engine/variables/
affine_view.rs1use std::cmp::Ordering;
2
3use enumset::EnumSet;
4
5use super::TransformableVariable;
6use crate::engine::notifications::DomainEvent;
7use crate::engine::notifications::OpaqueDomainEvent;
8use crate::engine::notifications::Watchers;
9use crate::engine::predicates::predicate::Predicate;
10use crate::engine::predicates::predicate_constructor::PredicateConstructor;
11use crate::engine::variables::DomainId;
12use crate::engine::variables::IntegerVariable;
13use crate::engine::Assignments;
14use crate::math::num_ext::NumExt;
15
/// A view over a variable `inner` that represents the value
/// `scale * inner + offset`.
///
/// The view only stores the wrapped variable and the two coefficients of the
/// transformation.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct AffineView<Inner> {
    /// The wrapped variable whose values are transformed.
    inner: Inner,
    /// Multiplicative coefficient applied to values of `inner`.
    scale: i32,
    /// Additive constant applied after scaling.
    offset: i32,
}
24
25impl<Inner> AffineView<Inner> {
26 pub fn new(inner: Inner, scale: i32, offset: i32) -> Self {
27 assert_ne!(scale, 0, "Multiplication by zero is not invertable");
28 AffineView {
29 inner,
30 scale,
31 offset,
32 }
33 }
34
35 fn invert(&self, value: i32, rounding: Rounding) -> i32 {
38 let inverted_translation = value - self.offset;
39
40 match rounding {
41 Rounding::Up => <i32 as NumExt>::div_ceil(inverted_translation, self.scale),
42 Rounding::Down => <i32 as NumExt>::div_floor(inverted_translation, self.scale),
43 }
44 }
45
46 fn map(&self, value: i32) -> i32 {
47 self.scale * value + self.offset
48 }
49}
50
51impl<View> IntegerVariable for AffineView<View>
52where
53 View: IntegerVariable,
54{
55 type AffineView = Self;
56
57 fn lower_bound(&self, assignment: &Assignments) -> i32 {
58 if self.scale < 0 {
59 self.map(self.inner.upper_bound(assignment))
60 } else {
61 self.map(self.inner.lower_bound(assignment))
62 }
63 }
64
65 fn lower_bound_at_trail_position(
66 &self,
67 assignment: &Assignments,
68 trail_position: usize,
69 ) -> i32 {
70 if self.scale < 0 {
71 self.map(
72 self.inner
73 .upper_bound_at_trail_position(assignment, trail_position),
74 )
75 } else {
76 self.map(
77 self.inner
78 .lower_bound_at_trail_position(assignment, trail_position),
79 )
80 }
81 }
82
83 fn upper_bound(&self, assignment: &Assignments) -> i32 {
84 if self.scale < 0 {
85 self.map(self.inner.lower_bound(assignment))
86 } else {
87 self.map(self.inner.upper_bound(assignment))
88 }
89 }
90
91 fn upper_bound_at_trail_position(
92 &self,
93 assignment: &Assignments,
94 trail_position: usize,
95 ) -> i32 {
96 if self.scale < 0 {
97 self.map(
98 self.inner
99 .lower_bound_at_trail_position(assignment, trail_position),
100 )
101 } else {
102 self.map(
103 self.inner
104 .upper_bound_at_trail_position(assignment, trail_position),
105 )
106 }
107 }
108
109 fn contains(&self, assignment: &Assignments, value: i32) -> bool {
110 if (value - self.offset) % self.scale == 0 {
111 let inverted = self.invert(value, Rounding::Up);
112 self.inner.contains(assignment, inverted)
113 } else {
114 false
115 }
116 }
117
118 fn contains_at_trail_position(
119 &self,
120 assignment: &Assignments,
121 value: i32,
122 trail_position: usize,
123 ) -> bool {
124 if (value - self.offset) % self.scale == 0 {
125 let inverted = self.invert(value, Rounding::Up);
126 self.inner
127 .contains_at_trail_position(assignment, inverted, trail_position)
128 } else {
129 false
130 }
131 }
132
133 fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator<Item = i32> {
134 self.inner
135 .iterate_domain(assignment)
136 .map(|value| self.map(value))
137 }
138
139 fn watch_all(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
140 let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
141 let intersection = events.intersection(bound);
142 if intersection.len() == 1 && self.scale.is_negative() {
143 events = events.symmetrical_difference(bound);
144 }
145 self.inner.watch_all(watchers, events);
146 }
147
148 fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
149 let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
150 let intersection = events.intersection(bound);
151 if intersection.len() == 1 && self.scale.is_negative() {
152 events = events.symmetrical_difference(bound);
153 }
154 self.inner.watch_all_backtrack(watchers, events);
155 }
156
157 fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent {
158 if self.scale.is_negative() {
159 match self.inner.unpack_event(event) {
160 DomainEvent::LowerBound => DomainEvent::UpperBound,
161 DomainEvent::UpperBound => DomainEvent::LowerBound,
162 event => event,
163 }
164 } else {
165 self.inner.unpack_event(event)
166 }
167 }
168
169 fn get_holes_on_current_decision_level(
170 &self,
171 assignments: &Assignments,
172 ) -> impl Iterator<Item = i32> {
173 self.inner
174 .get_holes_on_current_decision_level(assignments)
175 .map(|value| self.map(value))
176 }
177
178 fn get_holes(&self, assignments: &Assignments) -> impl Iterator<Item = i32> {
179 self.inner
180 .get_holes(assignments)
181 .map(|value| self.map(value))
182 }
183}
184
185impl<View> TransformableVariable<AffineView<View>> for AffineView<View>
186where
187 View: IntegerVariable,
188{
189 fn scaled(&self, scale: i32) -> AffineView<View> {
190 let mut result = self.clone();
191 result.scale *= scale;
192 result.offset *= scale;
193 result
194 }
195
196 fn offset(&self, offset: i32) -> AffineView<View> {
197 let mut result = self.clone();
198 result.offset += offset;
199 result
200 }
201}
202
203impl<Var: std::fmt::Debug> std::fmt::Debug for AffineView<Var> {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 if self.scale == -1 {
206 write!(f, "-")?;
207 } else if self.scale != 1 {
208 write!(f, "{} * ", self.scale)?;
209 }
210
211 write!(f, "({:?})", self.inner)?;
212
213 match self.offset.cmp(&0) {
214 Ordering::Less => write!(f, " - {}", -self.offset)?,
215 Ordering::Equal => {}
216 Ordering::Greater => write!(f, " + {}", self.offset)?,
217 }
218
219 Ok(())
220 }
221}
222
223impl<Var: PredicateConstructor<Value = i32>> PredicateConstructor for AffineView<Var> {
224 type Value = Var::Value;
225
226 fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate {
227 if self.scale < 0 {
228 let inverted_bound = self.invert(bound, Rounding::Down);
229 self.inner.upper_bound_predicate(inverted_bound)
230 } else {
231 let inverted_bound = self.invert(bound, Rounding::Up);
232 self.inner.lower_bound_predicate(inverted_bound)
233 }
234 }
235
236 fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate {
237 if self.scale < 0 {
238 let inverted_bound = self.invert(bound, Rounding::Up);
239 self.inner.lower_bound_predicate(inverted_bound)
240 } else {
241 let inverted_bound = self.invert(bound, Rounding::Down);
242 self.inner.upper_bound_predicate(inverted_bound)
243 }
244 }
245
246 fn equality_predicate(&self, bound: Self::Value) -> Predicate {
247 if (bound - self.offset) % self.scale == 0 {
248 let inverted_bound = self.invert(bound, Rounding::Up);
249 self.inner.equality_predicate(inverted_bound)
250 } else {
251 Predicate::trivially_false()
252 }
253 }
254
255 fn disequality_predicate(&self, bound: Self::Value) -> Predicate {
256 if (bound - self.offset) % self.scale == 0 {
257 let inverted_bound = self.invert(bound, Rounding::Up);
258 self.inner.disequality_predicate(inverted_bound)
259 } else {
260 Predicate::trivially_true()
261 }
262 }
263}
264
265impl From<DomainId> for AffineView<DomainId> {
266 fn from(value: DomainId) -> Self {
267 AffineView::new(value, 1, 0)
268 }
269}
270
/// Direction in which an inexact quotient is rounded when a view value is
/// translated back to the inner domain.
enum Rounding {
    /// Round the quotient up (ceiling division).
    Up,
    /// Round the quotient down (floor division).
    Down,
}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::predicate;

    /// Scaling multiplies both the scale and the offset of the view.
    #[test]
    fn scaling_an_affine_view() {
        let original = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(3, original.scale);
        assert_eq!(4, original.offset);

        let scaled = original.scaled(6);
        assert_eq!(18, scaled.scale);
        assert_eq!(24, scaled.offset);
    }

    /// Offsetting only shifts the offset and leaves the scale alone.
    #[test]
    fn offsetting_an_affine_view() {
        let original = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(3, original.scale);
        assert_eq!(4, original.offset);

        let shifted = original.offset(6);
        assert_eq!(3, shifted.scale);
        assert_eq!(10, shifted.offset);
    }

    /// Bounds that are not multiples of the scale are rounded towards the
    /// tightest bound on the inner variable.
    #[test]
    fn affine_view_obtaining_a_bound_should_round_optimistically_in_inner_domain() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, 2, 0);

        // Lower bounds round up.
        assert_eq!(predicate!(domain >= 1), predicate!(view >= 1));
        assert_eq!(predicate!(domain >= -1), predicate!(view >= -3));

        // Upper bounds round down.
        assert_eq!(predicate!(domain <= 0), predicate!(view <= 1));
        assert_eq!(predicate!(domain <= -3), predicate!(view <= -5));
    }

    /// With a negative scale the bound direction flips on the inner variable.
    #[test]
    fn test_negated_variable_has_bounds_rounded_correctly() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, -2, 0);

        assert_eq!(predicate!(view <= -3), predicate!(domain >= 2));
        assert_eq!(predicate!(view >= 5), predicate!(domain <= -3));
    }
}