// gol_core/neighbors/grid_donut.rs

1use super::util::{MarginPrimInt, PointPrimInt};
2use crate::cell::index::ToGridPointND;
3use crate::{BoardNeighborManager, GridPoint1D, GridPoint2D, GridPoint3D, GridPointND};
4use itertools::Itertools;
5use std::cmp::{max, min};
6
/// Neighbor calculator for a board whose edges wrap around (a "donut"), so that cells near one
/// edge of a dimension see cells at the opposite edge of that dimension as neighbors.
#[derive(Debug, Clone)]
pub struct NeighborsGridDonut<T> {
    /// Length of the board along each dimension, indexed by dimension.
    board_shape: Vec<T>,
    /// When `true`, `margins[0]` is applied to every dimension instead of one pair per dimension.
    should_repeat_margin: bool,
    /// `(negative side, positive side)` neighbor distances; one pair per dimension unless
    /// `should_repeat_margin` is set.
    margins: Vec<(T, T)>,
}
12
13impl<T> NeighborsGridDonut<T> {
14    pub fn new<I>(margin: T, board_shape: I) -> Self
15    where
16        T: Clone,
17        I: Iterator<Item = T>,
18    {
19        let margin_two_sides = vec![(margin.clone(), margin)];
20        Self {
21            should_repeat_margin: true,
22            margins: margin_two_sides,
23            board_shape: board_shape.collect(),
24        }
25    }
26
27    pub fn new_with_variable_margin<'a, 'b, I1, I2>(margins: I1, board_shape: I2) -> Self
28    where
29        'a: 'b,
30        T: 'a + Clone,
31        I1: Iterator<Item = &'b (T, T)>,
32        I2: Iterator<Item = T>,
33    {
34        let vec: Vec<(T, T)> = margins.map(|ele| (ele.0.clone(), ele.1.clone())).collect();
35        assert!(!vec.is_empty());
36        Self {
37            should_repeat_margin: false,
38            margins: vec,
39            board_shape: board_shape.collect(),
40        }
41    }
42
43    fn calc_grid_point_surrounding<U>(&self, idx: &GridPointND<U>) -> Vec<GridPointND<U>>
44    where
45        T: MarginPrimInt,
46        U: PointPrimInt,
47    {
48        let dim_ranges = self.calc_dim_ranges(idx);
49
50        // Expand dim ranges.
51        let mut indices_each_dim = Vec::with_capacity(dim_ranges.len());
52        for (ranges_1, ranges_2) in dim_ranges.iter() {
53            let mut cur = Vec::new();
54            let (cur_min, cur_max) = ranges_1;
55            for i in cur_min.to_i64().unwrap()..=cur_max.to_i64().unwrap() {
56                cur.push(U::from_i64(i).unwrap());
57            }
58            if ranges_2.is_some() {
59                let (cur_min, cur_max) = ranges_2.unwrap();
60                for i in cur_min.to_i64().unwrap()..=cur_max.to_i64().unwrap() {
61                    cur.push(U::from_i64(i).unwrap());
62                }
63            }
64            indices_each_dim.push(cur.into_iter());
65        }
66
67        let res = indices_each_dim
68            .into_iter()
69            .multi_cartesian_product()
70            .map(|ele| GridPointND::new(ele.iter()))
71            .filter(|ele| ele != idx)
72            .collect();
73        res
74    }
75
76    fn calc_dim_ranges<U>(&self, idx: &GridPointND<U>) -> Vec<((U, U), Option<(U, U)>)>
77    where
78        T: MarginPrimInt,
79        U: PointPrimInt,
80    {
81        let mut ranges = Vec::new();
82        for (i, dim_idx) in idx.indices().enumerate() {
83            let (neg, pos) = if self.should_repeat_margin {
84                self.margins.first().unwrap()
85            } else {
86                self.margins.get(i).unwrap()
87            };
88            let neg = U::from_usize(neg.to_usize().unwrap())
89                .expect("Index type too small to hold neighbor margin value.");
90            let pos = U::from_usize(pos.to_usize().unwrap())
91                .expect("Index type too small to hold neighbor margin value.");
92            let one = U::one();
93            let two = one + one;
94
95            let board_dim_len = U::from_usize(self.board_shape[i].to_usize().unwrap()).unwrap();
96            assert!(
97                board_dim_len.to_usize().unwrap()
98                    >= neg.to_usize().unwrap() + pos.to_usize().unwrap() + 1
99            );
100
101            let board_min = (board_dim_len / two).neg();
102            let board_max = board_dim_len / two
103                - if board_dim_len % two == one {
104                    U::zero()
105                } else {
106                    one
107                };
108
109            let mut wrapping_range: Option<(U, U)> = None;
110
111            let dim_idx_min_unchecked = dim_idx
112                .checked_sub(&neg)
113                .expect("Could not subtract points by margin value.");
114            let dim_idx_max_unchecked = dim_idx
115                .checked_add(&pos)
116                .expect("Could not add points by margin value.");
117            let dim_idx_min = max(board_min, dim_idx_min_unchecked);
118            let dim_idx_max = min(board_max, dim_idx_max_unchecked);
119
120            if dim_idx_min_unchecked < board_min {
121                let extension = dim_idx_min_unchecked - board_min;
122                wrapping_range = Some((board_max + extension + U::one(), board_max));
123            } else if dim_idx_max_unchecked > board_max {
124                let extension = dim_idx_max_unchecked - board_max;
125                wrapping_range = Some((board_min, board_min + extension - U::one()));
126            }
127
128            ranges.push(((dim_idx_min, dim_idx_max), wrapping_range));
129        }
130        ranges
131    }
132}
133
134impl<T, U> BoardNeighborManager<GridPointND<U>, std::vec::IntoIter<GridPointND<U>>>
135    for NeighborsGridDonut<T>
136where
137    T: MarginPrimInt,
138    U: PointPrimInt,
139{
140    fn get_neighbors_idx(&self, idx: &GridPointND<U>) -> std::vec::IntoIter<GridPointND<U>> {
141        self.calc_grid_point_surrounding(idx).into_iter()
142    }
143}
144
145impl<T, U> BoardNeighborManager<GridPoint3D<U>, std::vec::IntoIter<GridPoint3D<U>>>
146    for NeighborsGridDonut<T>
147where
148    T: MarginPrimInt,
149    U: PointPrimInt,
150{
151    fn get_neighbors_idx(&self, idx: &GridPoint3D<U>) -> std::vec::IntoIter<GridPoint3D<U>> {
152        let res: Vec<GridPoint3D<U>> = self
153            .calc_grid_point_surrounding(&idx.to_nd())
154            .iter()
155            .map(|ele| ele.to_3d().unwrap())
156            .collect();
157        res.into_iter()
158    }
159}
160
161impl<T, U> BoardNeighborManager<GridPoint2D<U>, std::vec::IntoIter<GridPoint2D<U>>>
162    for NeighborsGridDonut<T>
163where
164    T: MarginPrimInt,
165    U: PointPrimInt,
166{
167    fn get_neighbors_idx(&self, idx: &GridPoint2D<U>) -> std::vec::IntoIter<GridPoint2D<U>> {
168        let res: Vec<GridPoint2D<U>> = self
169            .calc_grid_point_surrounding(&idx.to_nd())
170            .iter()
171            .map(|ele| ele.to_2d().unwrap())
172            .collect();
173        res.into_iter()
174    }
175}
176
177impl<T, U> BoardNeighborManager<GridPoint1D<U>, std::vec::IntoIter<GridPoint1D<U>>>
178    for NeighborsGridDonut<T>
179where
180    T: MarginPrimInt,
181    U: PointPrimInt,
182{
183    fn get_neighbors_idx(&self, idx: &GridPoint1D<U>) -> std::vec::IntoIter<GridPoint1D<U>> {
184        let res: Vec<GridPoint1D<U>> = self
185            .calc_grid_point_surrounding(&idx.to_nd())
186            .iter()
187            .map(|ele| ele.to_1d().unwrap())
188            .collect();
189        res.into_iter()
190    }
191}
192
#[cfg(test)]
mod grid_donut_neighbor_test {
    use crate::{BoardNeighborManager, GridPoint1D, GridPoint2D, NeighborsGridDonut};

    /// Neighbors of `point` on a 1D board of `board_shape`, using a uniform margin of 1.
    fn neighbors_1d(board_shape: Vec<usize>, point: &GridPoint1D<i32>) -> Vec<GridPoint1D<i32>> {
        NeighborsGridDonut::new(1usize, board_shape.into_iter())
            .get_neighbors_idx(point)
            .collect()
    }

    /// Neighbors of `point` on a 2D board of `board_shape`, using a uniform margin of 1.
    fn neighbors_2d(board_shape: Vec<usize>, point: &GridPoint2D<i32>) -> Vec<GridPoint2D<i32>> {
        NeighborsGridDonut::new(1usize, board_shape.into_iter())
            .get_neighbors_idx(point)
            .collect()
    }

    /// Asserts that `neighbors` has as many entries as `expected`, contains every listed
    /// coordinate, and does not contain `center`.
    fn assert_neighbors_1d(
        neighbors: &[GridPoint1D<i32>],
        center: &GridPoint1D<i32>,
        expected: &[i32],
    ) {
        assert_eq!(neighbors.len(), expected.len());
        assert!(!neighbors.contains(center));
        for &x in expected {
            assert!(
                neighbors.contains(&GridPoint1D { x }),
                "missing neighbor x={}",
                x
            );
        }
    }

    /// Asserts that `neighbors` has as many entries as `expected`, contains every listed
    /// coordinate pair, and does not contain `center`.
    fn assert_neighbors_2d(
        neighbors: &[GridPoint2D<i32>],
        center: &GridPoint2D<i32>,
        expected: &[(i32, i32)],
    ) {
        assert_eq!(neighbors.len(), expected.len());
        assert!(!neighbors.contains(center));
        for &(x, y) in expected {
            assert!(
                neighbors.contains(&GridPoint2D { x, y }),
                "missing neighbor ({}, {})",
                x,
                y
            );
        }
    }

    #[test]
    fn grid_donut_test_1d_1() {
        let point = GridPoint1D { x: 10 };
        let neighbors = neighbors_1d(vec![100], &point);
        assert_neighbors_1d(&neighbors, &point, &[9, 11]);
    }

    #[test]
    fn grid_donut_test_1d_2() {
        let point = GridPoint1D { x: 0 };
        let neighbors = neighbors_1d(vec![3], &point);
        assert_neighbors_1d(&neighbors, &point, &[-1, 1]);
    }

    #[test]
    fn grid_donut_test_1d_variable_margin() {
        // Two cells to the left and none to the right: from the left edge both wrap around.
        let margins = vec![(2usize, 0usize)];
        let board_shape = vec![5usize];
        let neighbor_calc =
            NeighborsGridDonut::new_with_variable_margin(margins.iter(), board_shape.into_iter());
        let point = GridPoint1D { x: -2 };
        let neighbors: Vec<GridPoint1D<i32>> = neighbor_calc.get_neighbors_idx(&point).collect();
        assert_neighbors_1d(&neighbors, &point, &[1, 2]);
    }

    #[test]
    fn grid_donut_test_2d_1() {
        let point = GridPoint2D { x: -2, y: -2 };
        let neighbors = neighbors_2d(vec![5, 5], &point);
        assert_neighbors_2d(
            &neighbors,
            &point,
            &[
                (-2, -1),
                (-1, -1),
                (-1, -2),
                (2, -2),
                (2, -1),
                (-2, 2),
                (-1, 2),
                (2, 2),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_2() {
        let point = GridPoint2D { x: 2, y: 2 };
        let neighbors = neighbors_2d(vec![5, 5], &point);
        assert_neighbors_2d(
            &neighbors,
            &point,
            &[
                (-2, -2),
                (-2, 2),
                (-2, 1),
                (2, -2),
                (1, -2),
                (1, 2),
                (1, 1),
                (2, 1),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_3() {
        let point = GridPoint2D { x: 0, y: -24 };
        let neighbors = neighbors_2d(vec![100, 49], &point);
        assert_neighbors_2d(
            &neighbors,
            &point,
            &[
                (-1, -24),
                (-1, -23),
                (0, -23),
                (1, -24),
                (1, -23),
                (1, 24),
                (0, 24),
                (-1, 24),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_4() {
        let point = GridPoint2D { x: 0, y: -25 };
        let neighbors = neighbors_2d(vec![100, 50], &point);
        assert_neighbors_2d(
            &neighbors,
            &point,
            &[
                (-1, -25),
                (-1, -24),
                (0, -24),
                (1, -25),
                (1, -24),
                (1, 24),
                (0, 24),
                (-1, 24),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_5() {
        let point = GridPoint2D { x: 0, y: 24 };
        let neighbors = neighbors_2d(vec![100, 49], &point);
        assert_neighbors_2d(
            &neighbors,
            &point,
            &[
                (-1, 24),
                (-1, 23),
                (0, 23),
                (1, 24),
                (1, 23),
                (1, -24),
                (0, -24),
                (-1, -24),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_6() {
        let point = GridPoint2D { x: 0, y: 24 };
        let neighbors = neighbors_2d(vec![100, 50], &point);
        assert_neighbors_2d(
            &neighbors,
            &point,
            &[
                (-1, 24),
                (-1, 23),
                (0, 23),
                (1, 24),
                (1, 23),
                (1, -25),
                (0, -25),
                (-1, -25),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_7() {
        let point = GridPoint2D { x: 0, y: 24 };
        let neighbors = neighbors_2d(vec![171, 50], &point);
        assert_neighbors_2d(
            &neighbors,
            &point,
            &[
                (-1, 24),
                (-1, 23),
                (0, 23),
                (1, 24),
                (1, 23),
                (1, -25),
                (0, -25),
                (-1, -25),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_8() {
        let point = GridPoint2D { x: 1, y: 24 };
        let neighbors = neighbors_2d(vec![171, 50], &point);
        assert_neighbors_2d(
            &neighbors,
            &point,
            &[
                (0, 24),
                (0, 23),
                (1, 23),
                (2, 24),
                (2, 23),
                (2, -25),
                (1, -25),
                (0, -25),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_9() {
        let board_shape = vec![171usize, 50];
        let neighbor_calc = NeighborsGridDonut::new(1usize, board_shape.into_iter());
        for x in 0..171 {
            for y in 0..50 {
                let x_new = x - 171 / 2;
                let y_new = y - 50 / 2;
                let point = GridPoint2D::new(x_new, y_new);
                let cur_neighbors: Vec<GridPoint2D<i32>> =
                    neighbor_calc.get_neighbors_idx(&point).collect();
                assert_eq!(cur_neighbors.len(), 8);
                // Wrapping must never produce the same neighbor twice.
                let mut coords: Vec<(i32, i32)> =
                    cur_neighbors.iter().map(|p| (p.x, p.y)).collect();
                coords.sort_unstable();
                coords.dedup();
                assert_eq!(coords.len(), 8);
            }
        }
    }

    #[test]
    fn grid_donut_test_2d_variable_margin() {
        // Margin 1 along x, margin 0 along y: only horizontal neighbors, wrapping at the x edge.
        let margins = vec![(1usize, 1usize), (0usize, 0usize)];
        let board_shape = vec![5usize, 5];
        let neighbor_calc =
            NeighborsGridDonut::new_with_variable_margin(margins.iter(), board_shape.into_iter());
        let point = GridPoint2D { x: -2, y: 1 };
        let neighbors: Vec<GridPoint2D<i32>> = neighbor_calc.get_neighbors_idx(&point).collect();
        assert_neighbors_2d(&neighbors, &point, &[(-1, 1), (2, 1)]);
    }
}