1use super::util::{MarginPrimInt, PointPrimInt};
2use crate::{
3 BoardNeighborManager, GridPoint1D, GridPoint2D, GridPoint3D, GridPointND, NeighborMoore,
4};
5use itertools::izip;
6use std::convert::TryFrom;
7
/// Neighbor calculator that yields every grid point inside a rectangular
/// window around a point, excluding the point itself.
///
/// `T` is the integer type used to express the margins.
pub struct NeighborsGridSurround<T> {
    /// `true` when the first entry of `margins` is reused for every dimension
    /// (set by `new`); `false` when `margins` is looked up per dimension
    /// (set by `new_with_variable_margin`).
    should_repeat_margin: bool,
    /// Window extents as `(negative side, positive side)` pairs. Both
    /// constructors leave this vector non-empty.
    margins: Vec<(T, T)>,
}
12
13impl<T> NeighborsGridSurround<T> {
14 pub fn new(margin: T) -> Self
34 where
35 T: Clone,
36 {
37 let margin_two_sides = vec![(margin.clone(), margin)];
38 Self {
39 should_repeat_margin: true,
40 margins: margin_two_sides,
41 }
42 }
43
44 pub fn new_with_variable_margin<'a, 'b, I>(margins: I) -> Self
59 where
60 'a: 'b,
61 T: 'a + Clone,
62 I: Iterator<Item = &'b (T, T)>,
63 {
64 let vec: Vec<(T, T)> = margins.map(|ele| (ele.0.clone(), ele.1.clone())).collect();
65 assert!(!vec.is_empty());
66 Self {
67 should_repeat_margin: false,
68 margins: vec,
69 }
70 }
71
72 fn calc_grid_point_surrounding<U>(&self, idx: &GridPointND<U>) -> Vec<GridPointND<U>>
73 where
74 T: MarginPrimInt,
75 U: PointPrimInt,
76 {
77 let (dim_ranges, dim_lens, volume) = self.calc_dim_ranges(idx);
78
79 let mut i_exclude = 0usize;
81 let idx_indices: Vec<&U> = idx.indices().collect();
82 let mut cur_volume = volume;
83 for (cur_idx, dim_len, (dim_min, _)) in izip!(&idx_indices, &dim_lens, &dim_ranges).rev() {
84 cur_volume /= dim_len;
85 i_exclude += (**cur_idx - *dim_min).to_usize().unwrap() * cur_volume;
86 }
87
88 let mut res = Vec::new();
89 for i in 0..volume {
90 if i == i_exclude {
91 continue;
92 }
93
94 let (mut cur_i, mut cur_vol) = (i, volume);
95 let mut cur_indices = Vec::with_capacity(dim_lens.len());
96
97 for ((dim_min, _), dim_len) in dim_ranges.iter().zip(dim_lens.iter()).rev() {
98 cur_vol /= dim_len;
99 let dim_idx = cur_i / cur_vol;
100 cur_indices.push(U::from_usize(dim_idx).unwrap() + *dim_min);
101 cur_i %= cur_vol;
102 }
103 res.push(GridPointND::new(cur_indices.iter().rev()));
104 }
105 res
106 }
107
108 fn calc_dim_ranges<U>(&self, idx: &GridPointND<U>) -> (Vec<(U, U)>, Vec<usize>, usize)
109 where
110 T: MarginPrimInt,
111 U: PointPrimInt,
112 {
113 let mut ranges = Vec::new();
114 let mut dim_lens = Vec::new();
115 let mut volume = 1;
116 for (i, dim_idx) in idx.indices().enumerate() {
117 let (neg, pos) = if self.should_repeat_margin {
118 self.margins.first().unwrap()
119 } else {
120 self.margins.get(i).unwrap()
121 };
122
123 let mut dim_idx_min = None;
124 for n in (0..=neg.to_usize().unwrap()).rev() {
125 let n_u = U::from_usize(n).unwrap();
126 match dim_idx.checked_sub(&n_u) {
127 Some(val) => {
128 dim_idx_min = Some(val);
129 break;
130 }
131 None => continue,
132 }
133 }
134
135 let mut dim_idx_max = None;
136 for n in (0..=pos.to_usize().unwrap()).rev() {
137 let n_u = U::from_usize(n).unwrap();
138 match dim_idx.checked_add(&n_u) {
139 Some(val) => {
140 dim_idx_max = Some(val);
141 break;
142 }
143 None => continue,
144 }
145 }
146
147 let dim_idx_min = dim_idx_min.unwrap();
151 let dim_idx_max = dim_idx_max.unwrap();
152
153 ranges.push((dim_idx_min, dim_idx_max));
154 let dim_len = (dim_idx_max - dim_idx_min + U::one()).to_usize().unwrap();
155 dim_lens.push(dim_len);
156 volume *= dim_len;
157 }
158 (ranges, dim_lens, volume)
159 }
160}
161
162impl<T, U> BoardNeighborManager<GridPointND<U>, std::vec::IntoIter<GridPointND<U>>>
163 for NeighborsGridSurround<T>
164where
165 T: MarginPrimInt,
166 U: PointPrimInt,
167{
168 fn get_neighbors_idx(&self, idx: &GridPointND<U>) -> std::vec::IntoIter<GridPointND<U>> {
169 self.calc_grid_point_surrounding(idx).into_iter()
170 }
171}
172
173impl<T, U> BoardNeighborManager<GridPoint3D<U>, std::vec::IntoIter<GridPoint3D<U>>>
174 for NeighborsGridSurround<T>
175where
176 T: MarginPrimInt,
177 U: PointPrimInt + TryFrom<T>,
178{
179 fn get_neighbors_idx(&self, idx: &GridPoint3D<U>) -> std::vec::IntoIter<GridPoint3D<U>> {
180 let one_t = T::one();
181 let (x_left, x_right) = self.margins.first().unwrap();
182 let (mut y_left, mut y_right) = self.margins.first().unwrap();
183 let (z_left, z_right) = self.margins.last().unwrap();
184 if !self.should_repeat_margin {
185 let y_margin = self.margins[2];
186 y_left = y_margin.0;
187 y_right = y_margin.1;
188 }
189 if x_left == &one_t
190 && x_right == &one_t
191 && (self.should_repeat_margin
192 || y_left == one_t && y_right == one_t && z_left == &one_t && z_right == &one_t)
193 {
194 return NeighborMoore::new().get_neighbors_idx(idx);
195 }
196 let x_left_u = match U::try_from(*x_left) {
197 Ok(val) => val,
198 Err(_) => panic!("Error casting number."),
199 };
200 let y_left_u = match U::try_from(y_left) {
201 Ok(val) => val,
202 Err(_) => panic!("Error casting number."),
203 };
204 let z_left_u = match U::try_from(*z_left) {
205 Ok(val) => val,
206 Err(_) => panic!("Error casting number."),
207 };
208 let mut res = Vec::new();
209 let width = (*x_left + one_t + *x_right).to_usize().unwrap();
210 let height = (y_left + one_t + y_right).to_usize().unwrap();
211 let depth = (*z_left + one_t + *z_right).to_usize().unwrap();
212 let skip_idx = x_left.to_usize().unwrap()
213 + width * y_left.to_usize().unwrap()
214 + width * height * z_left.to_usize().unwrap();
215 for i in 0..(width * height * depth) {
216 if i == skip_idx {
217 continue;
218 }
219 let cur_x = x_left_u + U::from_usize(i % width).unwrap();
220 let cur_y = y_left_u + U::from_usize(i / width).unwrap();
221 let cur_z = z_left_u + U::from_usize(i / (width * height)).unwrap();
222 res.push(GridPoint3D::new(cur_x, cur_y, cur_z));
223 }
224 res.into_iter()
225 }
226}
227
228impl<T, U> BoardNeighborManager<GridPoint2D<U>, std::vec::IntoIter<GridPoint2D<U>>>
229 for NeighborsGridSurround<T>
230where
231 T: MarginPrimInt,
232 U: PointPrimInt + TryFrom<T>,
233{
234 fn get_neighbors_idx(&self, idx: &GridPoint2D<U>) -> std::vec::IntoIter<GridPoint2D<U>> {
235 let one_t = T::one();
236 let (x_left, x_right) = self.margins.first().unwrap();
237 let (y_left, y_right) = self.margins.last().unwrap();
238 if x_left == &one_t
239 && x_right == &one_t
240 && (self.should_repeat_margin || y_left == &one_t && y_right == &one_t)
241 {
242 return NeighborMoore::new().get_neighbors_idx(idx);
243 }
244 let x_left_u = match U::try_from(*x_left) {
245 Ok(val) => val,
246 Err(_) => panic!("Error casting number."),
247 };
248 let y_left_u = match U::try_from(*y_left) {
249 Ok(val) => val,
250 Err(_) => panic!("Error casting number."),
251 };
252 let mut res = Vec::new();
253 let width = (*x_left + one_t + *x_right).to_usize().unwrap();
254 let height = (*y_left + one_t + *y_right).to_usize().unwrap();
255 let skip_idx = x_left.to_usize().unwrap() + width * y_left.to_usize().unwrap();
256 for i in 0..(width * height) {
257 if i == skip_idx {
258 continue;
259 }
260 let cur_x = x_left_u + U::from_usize(i % width).unwrap();
261 let cur_y = y_left_u + U::from_usize(i / width).unwrap();
262 res.push(GridPoint2D::new(cur_x, cur_y));
263 }
264 res.into_iter()
265 }
266}
267
268impl<T, U> BoardNeighborManager<GridPoint1D<U>, std::vec::IntoIter<GridPoint1D<U>>>
269 for NeighborsGridSurround<T>
270where
271 T: MarginPrimInt,
272 U: PointPrimInt + TryFrom<T>,
273{
274 fn get_neighbors_idx(&self, idx: &GridPoint1D<U>) -> std::vec::IntoIter<GridPoint1D<U>> {
275 let (one_t, one_u) = (T::one(), U::one());
276 let (left, right) = self.margins.first().unwrap();
277 if left == &one_t && right == &one_t {
278 return NeighborMoore::new().get_neighbors_idx(idx);
279 }
280 let left = match U::try_from(*left) {
281 Ok(val) => val,
282 Err(_) => panic!("Error casting number."),
283 };
284 let right = match U::try_from(*right) {
285 Ok(val) => val,
286 Err(_) => panic!("Error casting number."),
287 };
288 let mut res = Vec::new();
289 let left_most_idx = idx.x - left;
290 for i in 0..(left + one_u + right).to_usize().unwrap() {
291 let cur_i = U::from_usize(i).unwrap();
292 let cur_x = cur_i + left_most_idx;
293 if cur_x == idx.x {
294 continue;
295 }
296 res.push(GridPoint1D::new(cur_x));
297 }
298 res.into_iter()
299 }
300}
301
#[cfg(test)]
mod grid_surrounding_neighbor_test {
    use crate::{
        BoardNeighborManager, GridPoint1D, GridPoint2D, GridPoint3D, GridPointND,
        NeighborsGridSurround,
    };

    /// 1D, margin 1, interior point: both direct neighbors are returned.
    #[test]
    fn grid_surrounding_test_1d_1() {
        let point = GridPoint1D { x: 10 };
        let surround = NeighborsGridSurround::new(1usize);
        let neighbors: Vec<GridPoint1D<i32>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 2);
        assert!(!neighbors.contains(&point));
        assert!(neighbors.contains(&GridPoint1D { x: 9 }));
        assert!(neighbors.contains(&GridPoint1D { x: 11 }));
    }

    /// 1D, margin 1, point at the origin of a signed coordinate type.
    #[test]
    fn grid_surrounding_test_1d_2() {
        let point = GridPoint1D { x: 0 };
        let surround = NeighborsGridSurround::new(1usize);
        let neighbors: Vec<GridPoint1D<i64>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 2);
        assert!(!neighbors.contains(&point));
        assert!(neighbors.contains(&GridPoint1D { x: 1 }));
    }

    /// 2D, margin 1, interior point: eight neighbors.
    #[test]
    fn grid_surrounding_test_2d_1() {
        let point = GridPoint2D { x: 10, y: 5 };
        let surround = NeighborsGridSurround::new(1usize);
        let neighbors: Vec<GridPoint2D<i32>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 8);
        assert!(!neighbors.contains(&point));
    }

    /// 2D, margin 1, point at the origin.
    #[test]
    fn grid_surrounding_test_2d_2() {
        let point = GridPoint2D { x: 0, y: 0 };
        let surround = NeighborsGridSurround::new(1usize);
        let neighbors: Vec<GridPoint2D<i64>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 8);
        assert!(!neighbors.contains(&point));
    }

    /// 2D, margin 1, point with a zero x coordinate.
    #[test]
    fn grid_surrounding_test_2d_3() {
        let point = GridPoint2D { x: 0, y: 1 };
        let surround = NeighborsGridSurround::new(1usize);
        let neighbors: Vec<GridPoint2D<i32>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 8);
        assert!(!neighbors.contains(&point));
    }

    /// 3D, margin 1, interior point: 26 neighbors.
    #[test]
    fn grid_surrounding_test_3d_1() {
        let point = GridPoint3D { x: 3, y: 10, z: 5 };
        let surround = NeighborsGridSurround::new(1usize);
        let neighbors: Vec<GridPoint3D<i32>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 26);
        assert!(!neighbors.contains(&point));
    }

    /// 3D, margin 2, point at the origin: 5^3 - 1 neighbors.
    #[test]
    fn grid_surrounding_test_3d_2() {
        let point = GridPoint3D { x: 0, y: 0, z: 0 };
        let surround = NeighborsGridSurround::new(2usize);
        let neighbors: Vec<GridPoint3D<i32>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 124);
        assert!(!neighbors.contains(&point));
    }

    /// 3D, margin 2, three points that each have one zero coordinate.
    #[test]
    fn grid_surrounding_test_3d_3() {
        let surround = NeighborsGridSurround::new(2usize);
        let point_1 = GridPoint3D { x: 0, y: 1, z: 1 };
        let point_2 = GridPoint3D { x: 1, y: 0, z: 1 };
        let point_3 = GridPoint3D { x: 1, y: 1, z: 0 };
        let neighbors_1: Vec<GridPoint3D<i32>> = surround.get_neighbors_idx(&point_1).collect();
        let neighbors_2: Vec<GridPoint3D<i32>> = surround.get_neighbors_idx(&point_2).collect();
        let neighbors_3: Vec<GridPoint3D<i32>> = surround.get_neighbors_idx(&point_3).collect();
        assert_eq!(neighbors_1.len(), 124);
        assert_eq!(neighbors_2.len(), neighbors_1.len());
        assert_eq!(neighbors_3.len(), neighbors_1.len());
        assert!(!neighbors_1.contains(&point_1));
        assert!(!neighbors_2.contains(&point_2));
        assert!(!neighbors_3.contains(&point_3));
    }

    /// 4D, margin 2, point at the origin: 5^4 - 1 neighbors.
    #[test]
    fn grid_surrounding_test_4d_1() {
        let point = GridPointND::new([0, 0, 0, 0].iter());
        let surround = NeighborsGridSurround::new(2usize);
        let neighbors: Vec<GridPointND<i32>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 624);
        assert!(!neighbors.contains(&point));
    }

    /// 4D with a different margin pair per dimension:
    /// 103 * 52 * 13 * 10 - 1 neighbors.
    #[test]
    fn grid_surrounding_test_4d_2() {
        let margins = vec![(100usize, 2), (50, 1), (10, 2), (0, 9)];
        let surround = NeighborsGridSurround::new_with_variable_margin(margins.iter());
        let point = GridPointND::new([0, 0, 1, 0].iter());
        let neighbors: Vec<GridPointND<i32>> = surround.get_neighbors_idx(&point).collect();
        assert_eq!(neighbors.len(), 696279);
        assert!(!neighbors.contains(&point));
    }
}