pumpkin_core/engine/variables/
affine_view.rs1use std::cmp::Ordering;
2
3use enumset::EnumSet;
4
5use super::TransformableVariable;
6use crate::engine::notifications::DomainEvent;
7use crate::engine::notifications::OpaqueDomainEvent;
8use crate::engine::notifications::Watchers;
9use crate::engine::predicates::predicate::Predicate;
10use crate::engine::predicates::predicate_constructor::PredicateConstructor;
11use crate::engine::variables::DomainId;
12use crate::engine::variables::IntegerVariable;
13use crate::engine::Assignments;
14use crate::math::num_ext::NumExt;
15
/// A view of an integer variable transformed by the affine map `scale * inner + offset`.
///
/// Bounds, membership queries, domain events and predicates on the view are translated to and
/// from the wrapped `inner` variable.
#[derive(Clone, Copy, Hash, Eq, PartialEq)]
pub struct AffineView<Inner> {
    /// The variable being viewed.
    inner: Inner,
    /// Multiplicative factor applied to values of `inner`. A negative scale reverses the order of
    /// the domain, so lower and upper bounds swap roles.
    scale: i32,
    /// Constant added after scaling.
    offset: i32,
}
24
25impl<Inner> AffineView<Inner> {
26 pub fn new(inner: Inner, scale: i32, offset: i32) -> Self {
27 AffineView {
28 inner,
29 scale,
30 offset,
31 }
32 }
33
34 fn invert(&self, value: i32, rounding: Rounding) -> i32 {
37 let inverted_translation = value - self.offset;
38
39 match rounding {
40 Rounding::Up => <i32 as NumExt>::div_ceil(inverted_translation, self.scale),
41 Rounding::Down => <i32 as NumExt>::div_floor(inverted_translation, self.scale),
42 }
43 }
44
45 fn map(&self, value: i32) -> i32 {
46 self.scale * value + self.offset
47 }
48}
49
50impl<View> IntegerVariable for AffineView<View>
51where
52 View: IntegerVariable,
53{
54 type AffineView = Self;
55
56 fn lower_bound(&self, assignment: &Assignments) -> i32 {
57 if self.scale < 0 {
58 self.map(self.inner.upper_bound(assignment))
59 } else {
60 self.map(self.inner.lower_bound(assignment))
61 }
62 }
63
64 fn lower_bound_at_trail_position(
65 &self,
66 assignment: &Assignments,
67 trail_position: usize,
68 ) -> i32 {
69 if self.scale < 0 {
70 self.map(
71 self.inner
72 .upper_bound_at_trail_position(assignment, trail_position),
73 )
74 } else {
75 self.map(
76 self.inner
77 .lower_bound_at_trail_position(assignment, trail_position),
78 )
79 }
80 }
81
82 fn upper_bound(&self, assignment: &Assignments) -> i32 {
83 if self.scale < 0 {
84 self.map(self.inner.lower_bound(assignment))
85 } else {
86 self.map(self.inner.upper_bound(assignment))
87 }
88 }
89
90 fn upper_bound_at_trail_position(
91 &self,
92 assignment: &Assignments,
93 trail_position: usize,
94 ) -> i32 {
95 if self.scale < 0 {
96 self.map(
97 self.inner
98 .lower_bound_at_trail_position(assignment, trail_position),
99 )
100 } else {
101 self.map(
102 self.inner
103 .upper_bound_at_trail_position(assignment, trail_position),
104 )
105 }
106 }
107
108 fn contains(&self, assignment: &Assignments, value: i32) -> bool {
109 if (value - self.offset) % self.scale == 0 {
110 let inverted = self.invert(value, Rounding::Up);
111 self.inner.contains(assignment, inverted)
112 } else {
113 false
114 }
115 }
116
117 fn contains_at_trail_position(
118 &self,
119 assignment: &Assignments,
120 value: i32,
121 trail_position: usize,
122 ) -> bool {
123 if (value - self.offset) % self.scale == 0 {
124 let inverted = self.invert(value, Rounding::Up);
125 self.inner
126 .contains_at_trail_position(assignment, inverted, trail_position)
127 } else {
128 false
129 }
130 }
131
132 fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator<Item = i32> {
133 self.inner
134 .iterate_domain(assignment)
135 .map(|value| self.map(value))
136 }
137
138 fn watch_all(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
139 let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
140 let intersection = events.intersection(bound);
141 if intersection.len() == 1 && self.scale.is_negative() {
142 events = events.symmetrical_difference(bound);
143 }
144 self.inner.watch_all(watchers, events);
145 }
146
147 fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, mut events: EnumSet<DomainEvent>) {
148 let bound = DomainEvent::LowerBound | DomainEvent::UpperBound;
149 let intersection = events.intersection(bound);
150 if intersection.len() == 1 && self.scale.is_negative() {
151 events = events.symmetrical_difference(bound);
152 }
153 self.inner.watch_all_backtrack(watchers, events);
154 }
155
156 fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent {
157 if self.scale.is_negative() {
158 match self.inner.unpack_event(event) {
159 DomainEvent::LowerBound => DomainEvent::UpperBound,
160 DomainEvent::UpperBound => DomainEvent::LowerBound,
161 event => event,
162 }
163 } else {
164 self.inner.unpack_event(event)
165 }
166 }
167
168 fn get_holes_on_current_decision_level(
169 &self,
170 assignments: &Assignments,
171 ) -> impl Iterator<Item = i32> {
172 self.inner
173 .get_holes_on_current_decision_level(assignments)
174 .map(|value| self.map(value))
175 }
176
177 fn get_holes(&self, assignments: &Assignments) -> impl Iterator<Item = i32> {
178 self.inner
179 .get_holes(assignments)
180 .map(|value| self.map(value))
181 }
182}
183
184impl<View> TransformableVariable<AffineView<View>> for AffineView<View>
185where
186 View: IntegerVariable,
187{
188 fn scaled(&self, scale: i32) -> AffineView<View> {
189 let mut result = self.clone();
190 result.scale *= scale;
191 result.offset *= scale;
192 result
193 }
194
195 fn offset(&self, offset: i32) -> AffineView<View> {
196 let mut result = self.clone();
197 result.offset += offset;
198 result
199 }
200}
201
202impl<Var: std::fmt::Debug> std::fmt::Debug for AffineView<Var> {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 if self.scale == -1 {
205 write!(f, "-")?;
206 } else if self.scale != 1 {
207 write!(f, "{} * ", self.scale)?;
208 }
209
210 write!(f, "({:?})", self.inner)?;
211
212 match self.offset.cmp(&0) {
213 Ordering::Less => write!(f, " - {}", -self.offset)?,
214 Ordering::Equal => {}
215 Ordering::Greater => write!(f, " + {}", self.offset)?,
216 }
217
218 Ok(())
219 }
220}
221
222impl<Var: PredicateConstructor<Value = i32>> PredicateConstructor for AffineView<Var> {
223 type Value = Var::Value;
224
225 fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate {
226 if self.scale < 0 {
227 let inverted_bound = self.invert(bound, Rounding::Down);
228 self.inner.upper_bound_predicate(inverted_bound)
229 } else {
230 let inverted_bound = self.invert(bound, Rounding::Up);
231 self.inner.lower_bound_predicate(inverted_bound)
232 }
233 }
234
235 fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate {
236 if self.scale < 0 {
237 let inverted_bound = self.invert(bound, Rounding::Up);
238 self.inner.lower_bound_predicate(inverted_bound)
239 } else {
240 let inverted_bound = self.invert(bound, Rounding::Down);
241 self.inner.upper_bound_predicate(inverted_bound)
242 }
243 }
244
245 fn equality_predicate(&self, bound: Self::Value) -> Predicate {
246 if (bound - self.offset) % self.scale == 0 {
247 let inverted_bound = self.invert(bound, Rounding::Up);
248 self.inner.equality_predicate(inverted_bound)
249 } else {
250 Predicate::trivially_false()
251 }
252 }
253
254 fn disequality_predicate(&self, bound: Self::Value) -> Predicate {
255 if (bound - self.offset) % self.scale == 0 {
256 let inverted_bound = self.invert(bound, Rounding::Up);
257 self.inner.disequality_predicate(inverted_bound)
258 } else {
259 Predicate::trivially_true()
260 }
261 }
262}
263
264impl From<DomainId> for AffineView<DomainId> {
265 fn from(value: DomainId) -> Self {
266 AffineView::new(value, 1, 0)
267 }
268}
269
/// Direction in which a non-integral quotient is rounded when a view value is mapped back into
/// the inner domain.
enum Rounding {
    /// Round towards positive infinity (ceiling division).
    Up,
    /// Round towards negative infinity (floor division).
    Down,
}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::predicate;

    #[test]
    fn scaling_an_affine_view() {
        let view = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(3, view.scale);
        assert_eq!(4, view.offset);

        let scaled_view = view.scaled(6);
        assert_eq!(18, scaled_view.scale);
        assert_eq!(24, scaled_view.offset);
    }

    #[test]
    fn offsetting_an_affine_view() {
        let view = AffineView::new(DomainId::new(0), 3, 4);
        assert_eq!(3, view.scale);
        assert_eq!(4, view.offset);

        let offset_view = view.offset(6);
        assert_eq!(3, offset_view.scale);
        assert_eq!(10, offset_view.offset);
    }

    #[test]
    fn affine_view_obtaining_a_bound_should_round_optimistically_in_inner_domain() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, 2, 0);

        assert_eq!(predicate!(domain >= 1), predicate!(view >= 1));
        assert_eq!(predicate!(domain >= -1), predicate!(view >= -3));
        assert_eq!(predicate!(domain <= 0), predicate!(view <= 1));
        assert_eq!(predicate!(domain <= -3), predicate!(view <= -5));
    }

    #[test]
    fn test_negated_variable_has_bounds_rounded_correctly() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, -2, 0);

        assert_eq!(predicate!(view <= -3), predicate!(domain >= 2));
        assert_eq!(predicate!(view >= 5), predicate!(domain <= -3));
    }

    #[test]
    fn equality_with_a_value_outside_the_image_of_the_view_is_trivial() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, 3, 1);

        // 5 - 1 = 4 is not a multiple of 3, so no inner value maps to 5.
        assert_eq!(Predicate::trivially_false(), view.equality_predicate(5));
        assert_eq!(Predicate::trivially_true(), view.disequality_predicate(5));
    }

    #[test]
    fn equality_on_a_negated_view_maps_to_the_inner_value() {
        let domain = DomainId::new(0);
        let view = AffineView::new(domain, -2, 1);

        // -2 * 2 + 1 = -3
        assert_eq!(domain.equality_predicate(2), view.equality_predicate(-3));
        assert_eq!(
            domain.disequality_predicate(2),
            view.disequality_predicate(-3)
        );
    }
}