1use super::util::{MarginPrimInt, PointPrimInt};
2use crate::cell::index::ToGridPointND;
3use crate::{BoardNeighborManager, GridPoint1D, GridPoint2D, GridPoint3D, GridPointND};
4use itertools::Itertools;
5use std::cmp::{max, min};
6
/// Neighbor calculator for a grid whose edges wrap around (a torus, "donut").
///
/// Coordinates are centered on zero: a dimension of length `len` covers
/// `-(len / 2)` through `-(len / 2) + len - 1`. A neighborhood that runs past
/// one edge of the board continues from the opposite edge.
#[derive(Debug, Clone)]
pub struct NeighborsGridDonut<T> {
    /// Length of the board along each dimension.
    board_shape: Vec<T>,
    /// When `true`, `margins[0]` is applied to every dimension; otherwise
    /// `margins[i]` is used for dimension `i`.
    should_repeat_margin: bool,
    /// `(negative, positive)` margin pairs: how many cells before and after a
    /// point (along one dimension) belong to its neighborhood.
    margins: Vec<(T, T)>,
}
12
13impl<T> NeighborsGridDonut<T> {
14 pub fn new<I>(margin: T, board_shape: I) -> Self
15 where
16 T: Clone,
17 I: Iterator<Item = T>,
18 {
19 let margin_two_sides = vec![(margin.clone(), margin)];
20 Self {
21 should_repeat_margin: true,
22 margins: margin_two_sides,
23 board_shape: board_shape.collect(),
24 }
25 }
26
27 pub fn new_with_variable_margin<'a, 'b, I1, I2>(margins: I1, board_shape: I2) -> Self
28 where
29 'a: 'b,
30 T: 'a + Clone,
31 I1: Iterator<Item = &'b (T, T)>,
32 I2: Iterator<Item = T>,
33 {
34 let vec: Vec<(T, T)> = margins.map(|ele| (ele.0.clone(), ele.1.clone())).collect();
35 assert!(!vec.is_empty());
36 Self {
37 should_repeat_margin: false,
38 margins: vec,
39 board_shape: board_shape.collect(),
40 }
41 }
42
43 fn calc_grid_point_surrounding<U>(&self, idx: &GridPointND<U>) -> Vec<GridPointND<U>>
44 where
45 T: MarginPrimInt,
46 U: PointPrimInt,
47 {
48 let dim_ranges = self.calc_dim_ranges(idx);
49
50 let mut indices_each_dim = Vec::with_capacity(dim_ranges.len());
52 for (ranges_1, ranges_2) in dim_ranges.iter() {
53 let mut cur = Vec::new();
54 let (cur_min, cur_max) = ranges_1;
55 for i in cur_min.to_i64().unwrap()..=cur_max.to_i64().unwrap() {
56 cur.push(U::from_i64(i).unwrap());
57 }
58 if ranges_2.is_some() {
59 let (cur_min, cur_max) = ranges_2.unwrap();
60 for i in cur_min.to_i64().unwrap()..=cur_max.to_i64().unwrap() {
61 cur.push(U::from_i64(i).unwrap());
62 }
63 }
64 indices_each_dim.push(cur.into_iter());
65 }
66
67 let res = indices_each_dim
68 .into_iter()
69 .multi_cartesian_product()
70 .map(|ele| GridPointND::new(ele.iter()))
71 .filter(|ele| ele != idx)
72 .collect();
73 res
74 }
75
76 fn calc_dim_ranges<U>(&self, idx: &GridPointND<U>) -> Vec<((U, U), Option<(U, U)>)>
77 where
78 T: MarginPrimInt,
79 U: PointPrimInt,
80 {
81 let mut ranges = Vec::new();
82 for (i, dim_idx) in idx.indices().enumerate() {
83 let (neg, pos) = if self.should_repeat_margin {
84 self.margins.first().unwrap()
85 } else {
86 self.margins.get(i).unwrap()
87 };
88 let neg = U::from_usize(neg.to_usize().unwrap())
89 .expect("Index type too small to hold neighbor margin value.");
90 let pos = U::from_usize(pos.to_usize().unwrap())
91 .expect("Index type too small to hold neighbor margin value.");
92 let one = U::one();
93 let two = one + one;
94
95 let board_dim_len = U::from_usize(self.board_shape[i].to_usize().unwrap()).unwrap();
96 assert!(
97 board_dim_len.to_usize().unwrap()
98 >= neg.to_usize().unwrap() + pos.to_usize().unwrap() + 1
99 );
100
101 let board_min = (board_dim_len / two).neg();
102 let board_max = board_dim_len / two
103 - if board_dim_len % two == one {
104 U::zero()
105 } else {
106 one
107 };
108
109 let mut wrapping_range: Option<(U, U)> = None;
110
111 let dim_idx_min_unchecked = dim_idx
112 .checked_sub(&neg)
113 .expect("Could not subtract points by margin value.");
114 let dim_idx_max_unchecked = dim_idx
115 .checked_add(&pos)
116 .expect("Could not add points by margin value.");
117 let dim_idx_min = max(board_min, dim_idx_min_unchecked);
118 let dim_idx_max = min(board_max, dim_idx_max_unchecked);
119
120 if dim_idx_min_unchecked < board_min {
121 let extension = dim_idx_min_unchecked - board_min;
122 wrapping_range = Some((board_max + extension + U::one(), board_max));
123 } else if dim_idx_max_unchecked > board_max {
124 let extension = dim_idx_max_unchecked - board_max;
125 wrapping_range = Some((board_min, board_min + extension - U::one()));
126 }
127
128 ranges.push(((dim_idx_min, dim_idx_max), wrapping_range));
129 }
130 ranges
131 }
132}
133
134impl<T, U> BoardNeighborManager<GridPointND<U>, std::vec::IntoIter<GridPointND<U>>>
135 for NeighborsGridDonut<T>
136where
137 T: MarginPrimInt,
138 U: PointPrimInt,
139{
140 fn get_neighbors_idx(&self, idx: &GridPointND<U>) -> std::vec::IntoIter<GridPointND<U>> {
141 self.calc_grid_point_surrounding(idx).into_iter()
142 }
143}
144
145impl<T, U> BoardNeighborManager<GridPoint3D<U>, std::vec::IntoIter<GridPoint3D<U>>>
146 for NeighborsGridDonut<T>
147where
148 T: MarginPrimInt,
149 U: PointPrimInt,
150{
151 fn get_neighbors_idx(&self, idx: &GridPoint3D<U>) -> std::vec::IntoIter<GridPoint3D<U>> {
152 let res: Vec<GridPoint3D<U>> = self
153 .calc_grid_point_surrounding(&idx.to_nd())
154 .iter()
155 .map(|ele| ele.to_3d().unwrap())
156 .collect();
157 res.into_iter()
158 }
159}
160
161impl<T, U> BoardNeighborManager<GridPoint2D<U>, std::vec::IntoIter<GridPoint2D<U>>>
162 for NeighborsGridDonut<T>
163where
164 T: MarginPrimInt,
165 U: PointPrimInt,
166{
167 fn get_neighbors_idx(&self, idx: &GridPoint2D<U>) -> std::vec::IntoIter<GridPoint2D<U>> {
168 let res: Vec<GridPoint2D<U>> = self
169 .calc_grid_point_surrounding(&idx.to_nd())
170 .iter()
171 .map(|ele| ele.to_2d().unwrap())
172 .collect();
173 res.into_iter()
174 }
175}
176
177impl<T, U> BoardNeighborManager<GridPoint1D<U>, std::vec::IntoIter<GridPoint1D<U>>>
178 for NeighborsGridDonut<T>
179where
180 T: MarginPrimInt,
181 U: PointPrimInt,
182{
183 fn get_neighbors_idx(&self, idx: &GridPoint1D<U>) -> std::vec::IntoIter<GridPoint1D<U>> {
184 let res: Vec<GridPoint1D<U>> = self
185 .calc_grid_point_surrounding(&idx.to_nd())
186 .iter()
187 .map(|ele| ele.to_1d().unwrap())
188 .collect();
189 res.into_iter()
190 }
191}
192
#[cfg(test)]
mod grid_donut_neighbor_test {
    use crate::{BoardNeighborManager, GridPoint1D, GridPoint2D, NeighborsGridDonut};

    /// Builds a 1D test point.
    fn p1(x: i32) -> GridPoint1D<i32> {
        GridPoint1D { x }
    }

    /// Builds a 2D test point.
    fn p2(x: i32, y: i32) -> GridPoint2D<i32> {
        GridPoint2D { x, y }
    }

    /// Neighbors of `point` on a 1D donut board of length `board_len` with a uniform `margin`.
    fn neighbors_1d(board_len: usize, margin: usize, point: &GridPoint1D<i32>) -> Vec<GridPoint1D<i32>> {
        NeighborsGridDonut::new(margin, vec![board_len].into_iter())
            .get_neighbors_idx(point)
            .collect()
    }

    /// Neighbors of `point` on a 2D donut board of `board_shape` with a uniform `margin`.
    fn neighbors_2d(board_shape: &[usize], margin: usize, point: &GridPoint2D<i32>) -> Vec<GridPoint2D<i32>> {
        NeighborsGridDonut::new(margin, board_shape.iter().cloned())
            .get_neighbors_idx(point)
            .collect()
    }

    /// Asserts that `neighbors` holds exactly the (distinct) points of `expected`,
    /// in any order, and does not contain `center`.
    fn assert_neighbors<P: PartialEq>(neighbors: &[P], center: &P, expected: &[P]) {
        assert_eq!(neighbors.len(), expected.len());
        assert!(!neighbors.contains(center));
        for point in expected {
            assert!(neighbors.contains(point));
        }
    }

    #[test]
    fn grid_donut_test_1d_1() {
        let point = p1(10);
        let neighbors = neighbors_1d(100, 1, &point);
        assert_neighbors(&neighbors, &point, &[p1(9), p1(11)]);
    }

    #[test]
    fn grid_donut_test_1d_2() {
        let point = p1(0);
        let neighbors = neighbors_1d(3, 1, &point);
        assert_neighbors(&neighbors, &point, &[p1(-1), p1(1)]);
    }

    #[test]
    fn grid_donut_test_1d_exact_fit_at_edge() {
        // Margin 1 on a length-3 board spans the whole dimension; from the
        // positive edge the wrapped cell must not duplicate an in-board cell.
        let point = p1(1);
        let neighbors = neighbors_1d(3, 1, &point);
        assert_neighbors(&neighbors, &point, &[p1(0), p1(-1)]);
    }

    #[test]
    fn grid_donut_test_1d_variable_margin() {
        // Board of length 10 spans [-5, 4]; two cells before, one after.
        let margins = vec![(2usize, 1usize)];
        let neighbor_calc =
            NeighborsGridDonut::new_with_variable_margin(margins.iter(), vec![10usize].into_iter());

        let center = p1(0);
        let neighbors: Vec<GridPoint1D<i32>> = neighbor_calc.get_neighbors_idx(&center).collect();
        assert_neighbors(&neighbors, &center, &[p1(-2), p1(-1), p1(1)]);

        let neg_edge = p1(-5);
        let neighbors: Vec<GridPoint1D<i32>> = neighbor_calc.get_neighbors_idx(&neg_edge).collect();
        assert_neighbors(&neighbors, &neg_edge, &[p1(-4), p1(3), p1(4)]);

        let pos_edge = p1(4);
        let neighbors: Vec<GridPoint1D<i32>> = neighbor_calc.get_neighbors_idx(&pos_edge).collect();
        assert_neighbors(&neighbors, &pos_edge, &[p1(2), p1(3), p1(-5)]);
    }

    #[test]
    fn grid_donut_test_2d_1() {
        let point = p2(-2, -2);
        let neighbors = neighbors_2d(&[5, 5], 1, &point);
        assert_neighbors(
            &neighbors,
            &point,
            &[
                p2(-2, -1),
                p2(-1, -1),
                p2(-1, -2),
                p2(2, -2),
                p2(2, -1),
                p2(-2, 2),
                p2(-1, 2),
                p2(2, 2),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_2() {
        let point = p2(2, 2);
        let neighbors = neighbors_2d(&[5, 5], 1, &point);
        assert_neighbors(
            &neighbors,
            &point,
            &[
                p2(-2, -2),
                p2(-2, 2),
                p2(-2, 1),
                p2(2, -2),
                p2(1, -2),
                p2(1, 2),
                p2(1, 1),
                p2(2, 1),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_3() {
        let point = p2(0, -24);
        let neighbors = neighbors_2d(&[100, 49], 1, &point);
        assert_neighbors(
            &neighbors,
            &point,
            &[
                p2(-1, -24),
                p2(-1, -23),
                p2(0, -23),
                p2(1, -24),
                p2(1, -23),
                p2(1, 24),
                p2(0, 24),
                p2(-1, 24),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_4() {
        let point = p2(0, -25);
        let neighbors = neighbors_2d(&[100, 50], 1, &point);
        assert_neighbors(
            &neighbors,
            &point,
            &[
                p2(-1, -25),
                p2(-1, -24),
                p2(0, -24),
                p2(1, -25),
                p2(1, -24),
                p2(1, 24),
                p2(0, 24),
                p2(-1, 24),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_5() {
        let point = p2(0, 24);
        let neighbors = neighbors_2d(&[100, 49], 1, &point);
        assert_neighbors(
            &neighbors,
            &point,
            &[
                p2(-1, 24),
                p2(-1, 23),
                p2(0, 23),
                p2(1, 24),
                p2(1, 23),
                p2(1, -24),
                p2(0, -24),
                p2(-1, -24),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_6() {
        let point = p2(0, 24);
        let neighbors = neighbors_2d(&[100, 50], 1, &point);
        assert_neighbors(
            &neighbors,
            &point,
            &[
                p2(-1, 24),
                p2(-1, 23),
                p2(0, 23),
                p2(1, 24),
                p2(1, 23),
                p2(1, -25),
                p2(0, -25),
                p2(-1, -25),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_7() {
        let point = p2(0, 24);
        let neighbors = neighbors_2d(&[171, 50], 1, &point);
        assert_neighbors(
            &neighbors,
            &point,
            &[
                p2(-1, 24),
                p2(-1, 23),
                p2(0, 23),
                p2(1, 24),
                p2(1, 23),
                p2(1, -25),
                p2(0, -25),
                p2(-1, -25),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_8() {
        let point = p2(1, 24);
        let neighbors = neighbors_2d(&[171, 50], 1, &point);
        assert_neighbors(
            &neighbors,
            &point,
            &[
                p2(0, 24),
                p2(0, 23),
                p2(1, 23),
                p2(2, 24),
                p2(2, 23),
                p2(2, -25),
                p2(1, -25),
                p2(0, -25),
            ],
        );
    }

    #[test]
    fn grid_donut_test_2d_9() {
        let board_shape = vec![171usize, 50];
        let neighbor_calc = NeighborsGridDonut::new(1usize, board_shape.into_iter());
        for x in 0..171 {
            for y in 0..50 {
                let x_new = x - 171 / 2;
                let y_new = y - 50 / 2;
                let point = GridPoint2D::new(x_new, y_new);
                let cur_neighbors: Vec<GridPoint2D<i32>> =
                    neighbor_calc.get_neighbors_idx(&point).collect();
                assert_eq!(cur_neighbors.len(), 8);
            }
        }
    }

    #[test]
    fn grid_donut_test_2d_margin_covers_board() {
        // A margin of 2 on a 5x5 board reaches every other cell exactly once,
        // including from positions where one or both axes wrap.
        for &(cx, cy) in &[(-2, -2), (0, 0), (2, 1)] {
            let center = p2(cx, cy);
            let expected: Vec<GridPoint2D<i32>> = (-2..=2)
                .flat_map(|x| (-2..=2).map(move |y| p2(x, y)))
                .filter(|q| q != &center)
                .collect();
            let neighbors = neighbors_2d(&[5, 5], 2, &center);
            assert_neighbors(&neighbors, &center, &expected);
        }
    }

    #[test]
    fn grid_donut_test_2d_variable_margin() {
        // Margin only along x; y contributes just the point's own row.
        let margins = vec![(1usize, 1usize), (0, 0)];
        let neighbor_calc =
            NeighborsGridDonut::new_with_variable_margin(margins.iter(), vec![5usize, 5].into_iter());
        let point = p2(2, 0);
        let neighbors: Vec<GridPoint2D<i32>> = neighbor_calc.get_neighbors_idx(&point).collect();
        assert_neighbors(&neighbors, &point, &[p2(1, 0), p2(-2, 0)]);
    }
}