pumpkin_core/engine/variables/
affine_view.rs1use std::cmp::Ordering;
2
3use enumset::EnumSet;
4use pumpkin_checking::CheckerVariable;
5use pumpkin_checking::IntExt;
6
7use super::TransformableVariable;
8use crate::engine::Assignments;
9use crate::engine::notifications::DomainEvent;
10use crate::engine::notifications::OpaqueDomainEvent;
11use crate::engine::notifications::Watchers;
12use crate::engine::predicates::predicate::Predicate;
13use crate::engine::predicates::predicate_constructor::PredicateConstructor;
14use crate::engine::variables::DomainId;
15use crate::engine::variables::IntegerVariable;
16use crate::math::num_ext::NumExt;
17
/// A view that presents an inner variable `x` as the affine expression `scale * x + offset`.
///
/// Values, bounds and predicates on the view are translated to the inner variable by
/// inverting this mapping, which is why `scale` is expected to be non-zero (see
/// [`AffineView::new`]).
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
pub struct AffineView<Inner> {
    /// The underlying variable being transformed.
    pub(crate) inner: Inner,
    /// Coefficient multiplied with the inner variable.
    pub(crate) scale: i32,
    /// Constant added after scaling.
    pub(crate) offset: i32,
}
26
27impl<Inner> AffineView<Inner> {
28 pub fn new(inner: Inner, scale: i32, offset: i32) -> Self {
29 assert_ne!(scale, 0, "Multiplication by zero is not invertable");
30 AffineView {
31 inner,
32 scale,
33 offset,
34 }
35 }
36
37 pub fn inner(&self) -> &Inner {
38 &self.inner
39 }
40
41 fn invert(&self, value: i32, rounding: Rounding) -> i32 {
44 let inverted_translation = value - self.offset;
45
46 match rounding {
47 Rounding::Up => <i32 as NumExt>::div_ceil(inverted_translation, self.scale),
48 Rounding::Down => <i32 as NumExt>::div_floor(inverted_translation, self.scale),
49 }
50 }
51
52 fn map(&self, value: i32) -> i32 {
53 self.scale * value + self.offset
54 }
55}
56
57impl<Var: IntegerVariable> CheckerVariable<Predicate> for AffineView<Var> {
58 fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool {
59 self.inner.does_atomic_constrain_self(atomic)
60 }
61
62 fn atomic_less_than(&self, value: i32) -> Predicate {
63 use crate::predicate;
64
65 predicate![self <= value]
66 }
67
68 fn atomic_greater_than(&self, value: i32) -> Predicate {
69 use crate::predicate;
70
71 predicate![self >= value]
72 }
73
74 fn atomic_equal(&self, value: i32) -> Predicate {
75 use crate::predicate;
76
77 predicate![self == value]
78 }
79
80 fn atomic_not_equal(&self, value: i32) -> Predicate {
81 use crate::predicate;
82
83 predicate![self != value]
84 }
85
86 fn induced_lower_bound(
87 &self,
88 variable_state: &pumpkin_checking::VariableState<Predicate>,
89 ) -> IntExt {
90 if self.scale.is_positive() {
91 match self.inner.induced_lower_bound(variable_state) {
92 IntExt::Int(value) => IntExt::Int(self.map(value)),
93 bound => bound,
94 }
95 } else {
96 match self.inner.induced_upper_bound(variable_state) {
97 IntExt::Int(value) => IntExt::Int(self.map(value)),
98 IntExt::NegativeInf => IntExt::PositiveInf,
99 IntExt::PositiveInf => IntExt::NegativeInf,
100 }
101 }
102 }
103
104 fn induced_upper_bound(
105 &self,
106 variable_state: &pumpkin_checking::VariableState<Predicate>,
107 ) -> IntExt {
108 if self.scale.is_positive() {
109 match self.inner.induced_upper_bound(variable_state) {
110 IntExt::Int(value) => IntExt::Int(self.map(value)),
111 bound => bound,
112 }
113 } else {
114 match self.inner.induced_lower_bound(variable_state) {
115 IntExt::Int(value) => IntExt::Int(self.map(value)),
116 IntExt::NegativeInf => IntExt::PositiveInf,
117 IntExt::PositiveInf => IntExt::NegativeInf,
118 }
119 }
120 }
121
122 fn induced_fixed_value(
123 &self,
124 variable_state: &pumpkin_checking::VariableState<Predicate>,
125 ) -> Option<i32> {
126 self.inner
127 .induced_fixed_value(variable_state)
128 .map(|value| self.map(value))
129 }
130
131 fn induced_domain_contains(
132 &self,
133 variable_state: &pumpkin_checking::VariableState<Predicate>,
134 value: i32,
135 ) -> bool {
136 let translated_value = value - self.offset;
137
138 if translated_value % self.scale != 0 {
141 return false;
142 }
143
144 let unscaled_value = translated_value / self.scale;
145
146 self.inner
147 .induced_domain_contains(variable_state, unscaled_value)
148 }
149
150 fn induced_holes<'this, 'state>(
151 &'this self,
152 variable_state: &'state pumpkin_checking::VariableState<Predicate>,
153 ) -> impl Iterator<Item = i32> + 'state
154 where
155 'this: 'state,
156 {
157 if self.scale == 1 || self.scale == -1 {
158 return self
159 .inner
160 .induced_holes(variable_state)
161 .map(|value| self.map(value));
162 }
163
164 todo!("how to iterate holes of a scaled domain");
165 }
166
167 fn iter_induced_domain<'this, 'state>(
168 &'this self,
169 variable_state: &'state pumpkin_checking::VariableState<Predicate>,
170 ) -> Option<impl Iterator<Item = i32> + 'state>
171 where
172 'this: 'state,
173 {
174 self.inner
175 .iter_induced_domain(variable_state)
176 .map(|iter| iter.map(|value| self.map(value)))
177 }
178}
179
180impl<View> IntegerVariable for AffineView<View>
181where
182 View: IntegerVariable,
183{
184 type AffineView = Self;
185
186 fn lower_bound(&self, assignment: &Assignments) -> i32 {
187 if self.scale < 0 {
188 self.map(self.inner.upper_bound(assignment))
189 } else {
190 self.map(self.inner.lower_bound(assignment))
191 }
192 }
193
194 fn lower_bound_at_trail_position(
195 &self,
196 assignment: &Assignments,
197 trail_position: usize,
198 ) -> i32 {
199 if self.scale < 0 {
200 self.map(
201 self.inner
202 .upper_bound_at_trail_position(assignment, trail_position),
203 )
204 } else {
205 self.map(
206 self.inner
207 .lower_bound_at_trail_position(assignment, trail_position),
208 )
209 }
210 }
211
212 fn upper_bound(&self, assignment: &Assignments) -> i32 {
213 if self.scale < 0 {
214 self.map(self.inner.lower_bound(assignment))
215 } else {
216 self.map(self.inner.upper_bound(assignment))
217 }
218 }
219
220 fn upper_bound_at_trail_position(
221 &self,
222 assignment: &Assignments,
223 trail_position: usize,
224 ) -> i32 {
225 if self.scale < 0 {
226 self.map(
227 self.inner
228 .lower_bound_at_trail_position(assignment, trail_position),
229 )
230 } else {
231 self.map(
232 self.inner
233 .upper_bound_at_trail_position(assignment, trail_position),
234 )
235 }
236 }
237
238 fn contains(&self, assignment: &Assignments, value: i32) -> bool {
239 if (value - self.offset) % self.scale == 0 {
240 let inverted = self.invert(value, Rounding::Up);
241 self.inner.contains(assignment, inverted)
242 } else {
243 false
244 }
245 }
246
247 fn contains_at_trail_position(
248 &self,
249 assignment: &Assignments,
250 value: i32,
251 trail_position: usize,
252 ) -> bool {
253 if (value - self.offset) % self.scale == 0 {
254 let inverted = self.invert(value, Rounding::Up);
255 self.inner
256 .contains_at_trail_position(assignment, inverted, trail_position)
257 } else {
258 false
259 }
260 }
261
262 fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator<Item = i32> {
263 self.inner
264 .iterate_domain(assignment)
265 .map(|value| self.map(value))
266 }
267
268 fn watch_all(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
269 let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
270 let intersection = events.intersection(bound);
271 if intersection.len() == 1 && self.scale.is_negative() {
272 events = events.symmetrical_difference(bound);
273 }
274 self.inner.watch_all(watchers, events);
275 }
276
277 fn unwatch_all(&self, watchers: &mut Watchers<'_>) {
278 self.inner.unwatch_all(watchers);
279 }
280
281 fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
282 let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
283 let intersection = events.intersection(bound);
284 if intersection.len() == 1 && self.scale.is_negative() {
285 events = events.symmetrical_difference(bound);
286 }
287 self.inner.watch_all_backtrack(watchers, events);
288 }
289
290 fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent {
291 if self.scale.is_negative() {
292 match self.inner.unpack_event(event) {
293 DomainEvent::LowerBound => DomainEvent::UpperBound,
294 DomainEvent::UpperBound => DomainEvent::LowerBound,
295 event => event,
296 }
297 } else {
298 self.inner.unpack_event(event)
299 }
300 }
301
302 fn get_holes_at_current_checkpoint(
303 &self,
304 assignments: &Assignments,
305 ) -> impl Iterator<Item = i32> {
306 self.inner
307 .get_holes_at_current_checkpoint(assignments)
308 .map(|value| self.map(value))
309 }
310
311 fn get_holes(&self, assignments: &Assignments) -> impl Iterator<Item = i32> {
312 self.inner
313 .get_holes(assignments)
314 .map(|value| self.map(value))
315 }
316}
317
318impl<View> TransformableVariable<AffineView<View>> for AffineView<View>
319where
320 View: IntegerVariable,
321{
322 fn scaled(&self, scale: i32) -> AffineView<View> {
323 let mut result = self.clone();
324 result.scale *= scale;
325 result.offset *= scale;
326 result
327 }
328
329 fn offset(&self, offset: i32) -> AffineView<View> {
330 let mut result = self.clone();
331 result.offset += offset;
332 result
333 }
334}
335
336impl<Var: std::fmt::Debug> std::fmt::Debug for AffineView<Var> {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 if self.scale == -1 {
339 write!(f, "-")?;
340 } else if self.scale != 1 {
341 write!(f, "{} * ", self.scale)?;
342 }
343
344 write!(f, "({:?})", self.inner)?;
345
346 match self.offset.cmp(&0) {
347 Ordering::Less => write!(f, " - {}", -self.offset)?,
348 Ordering::Equal => {}
349 Ordering::Greater => write!(f, " + {}", self.offset)?,
350 }
351
352 Ok(())
353 }
354}
355
356impl<Var: PredicateConstructor<Value = i32>> PredicateConstructor for AffineView<Var> {
357 type Value = Var::Value;
358
359 fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate {
360 if self.scale < 0 {
361 let inverted_bound = self.invert(bound, Rounding::Down);
362 self.inner.upper_bound_predicate(inverted_bound)
363 } else {
364 let inverted_bound = self.invert(bound, Rounding::Up);
365 self.inner.lower_bound_predicate(inverted_bound)
366 }
367 }
368
369 fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate {
370 if self.scale < 0 {
371 let inverted_bound = self.invert(bound, Rounding::Up);
372 self.inner.lower_bound_predicate(inverted_bound)
373 } else {
374 let inverted_bound = self.invert(bound, Rounding::Down);
375 self.inner.upper_bound_predicate(inverted_bound)
376 }
377 }
378
379 fn equality_predicate(&self, bound: Self::Value) -> Predicate {
380 if (bound - self.offset) % self.scale == 0 {
381 let inverted_bound = self.invert(bound, Rounding::Up);
382 self.inner.equality_predicate(inverted_bound)
383 } else {
384 Predicate::trivially_false()
385 }
386 }
387
388 fn disequality_predicate(&self, bound: Self::Value) -> Predicate {
389 if (bound - self.offset) % self.scale == 0 {
390 let inverted_bound = self.invert(bound, Rounding::Up);
391 self.inner.disequality_predicate(inverted_bound)
392 } else {
393 Predicate::trivially_true()
394 }
395 }
396}
397
398impl From<DomainId> for AffineView<DomainId> {
399 fn from(value: DomainId) -> Self {
400 AffineView::new(value, 1, 0)
401 }
402}
403
/// Direction in which [`AffineView::invert`] rounds an inexact quotient.
enum Rounding {
    /// Round towards positive infinity (ceiling).
    Up,
    /// Round towards negative infinity (floor).
    Down,
}
408
#[cfg(test)]
mod tests {
    use super::*;
    use crate::predicate;

    #[test]
    fn scaling_an_affine_view() {
        let base = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!((base.scale, base.offset), (3, 4));

        let scaled = base.scaled(6);
        assert_eq!((scaled.scale, scaled.offset), (18, 24));
    }

    #[test]
    fn offsetting_an_affine_view() {
        let base = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!((base.scale, base.offset), (3, 4));

        let shifted = base.offset(6);
        assert_eq!((shifted.scale, shifted.offset), (3, 10));
    }

    #[test]
    fn affine_view_obtaining_a_bound_should_round_optimistically_in_inner_domain() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, 2, 0);

        // Lower bounds round up: ceil(1 / 2) = 1 and ceil(-3 / 2) = -1.
        assert_eq!(predicate!(domain >= 1), predicate!(view >= 1));
        assert_eq!(predicate!(domain >= -1), predicate!(view >= -3));

        // Upper bounds round down: floor(1 / 2) = 0 and floor(-5 / 2) = -3.
        assert_eq!(predicate!(domain <= 0), predicate!(view <= 1));
        assert_eq!(predicate!(domain <= -3), predicate!(view <= -5));
    }

    #[test]
    fn test_negated_variable_has_bounds_rounded_correctly() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, -2, 0);

        // -2x <= -3  <=>  x >= 1.5  <=>  x >= 2
        assert_eq!(predicate!(view <= -3), predicate!(domain >= 2));
        // -2x >= 5  <=>  x <= -2.5  <=>  x <= -3
        assert_eq!(predicate!(view >= 5), predicate!(domain <= -3));
    }
}