1use super::types::Point;
8use nalgebra::Vector3;
9
/// "No item" marker stored in the `head` / `next` index tables of [`Grid`].
/// All bits set, i.e. the same value as `u32::MAX`.
const SENTINEL: u32 = !0;
12
/// Uniform 3-D spatial hash grid over a fixed set of `(position, value)` items.
///
/// Items are bucketed into cubic cells of edge `cell_size`. Each cell's items form a
/// singly linked list threaded through the flat `head` and `next` index tables, so the
/// whole structure is a few flat vectors rather than one collection per cell.
#[derive(Debug, Clone)]
pub struct Grid<T> {
    /// Edge length of every cell (the same value is used on all three axes).
    cell_size: f64,
    /// Minimum corner of the grid: the component-wise minimum of the item positions,
    /// or `Point::origin()` for an empty grid.
    origin: Point,
    /// Number of cells along x, y and z; all zero for an empty grid.
    dims: Vector3<usize>,
    /// One entry per cell (flattened with x fastest, then y, then z): index into
    /// `items` of the first item in that cell, or `SENTINEL` if the cell is empty.
    head: Vec<u32>,
    /// One entry per item: index of the next item in the same cell, or `SENTINEL`
    /// at the end of that cell's list.
    next: Vec<u32>,
    /// All items in input order, each paired with its position.
    items: Vec<(Point, T)>,
}
38
39impl<T> Grid<T> {
40 pub fn new(items: impl IntoIterator<Item = (Point, T)>, cell_size: f64) -> Self {
54 assert!(cell_size > 0.0, "Cell size must be positive");
55
56 let input_items: Vec<_> = items.into_iter().collect();
57 let num_items = input_items.len();
58
59 if num_items == 0 {
60 return Self {
61 cell_size,
62 origin: Point::origin(),
63 dims: Vector3::zeros(),
64 head: Vec::new(),
65 next: Vec::new(),
66 items: Vec::new(),
67 };
68 }
69
70 let mut min = Point::new(f64::MAX, f64::MAX, f64::MAX);
71 let mut max = Point::new(f64::MIN, f64::MIN, f64::MIN);
72
73 for (pos, _) in &input_items {
74 min = min.inf(pos);
75 max = max.sup(pos);
76 }
77
78 let epsilon = 1e-6;
79 max += Vector3::new(epsilon, epsilon, epsilon);
80
81 let extent = max - min;
82 let dims = Vector3::new(
83 (extent.x / cell_size).ceil() as usize,
84 (extent.y / cell_size).ceil() as usize,
85 (extent.z / cell_size).ceil() as usize,
86 );
87
88 let total_cells = dims.x * dims.y * dims.z;
89
90 let mut head = vec![SENTINEL; total_cells];
91 let mut next = vec![SENTINEL; num_items];
92 let mut stored_items = Vec::with_capacity(num_items);
93
94 for (i, (pos, item)) in input_items.into_iter().enumerate() {
95 stored_items.push((pos, item));
96
97 if let Some(cell_idx) = Self::get_cell_index_static(&pos, dims, min, cell_size) {
98 next[i] = head[cell_idx];
99 head[cell_idx] = i as u32;
100 }
101 }
102
103 Self {
104 cell_size,
105 origin: min,
106 dims,
107 head,
108 next,
109 items: stored_items,
110 }
111 }
112
113 fn get_cell_index_static(
115 pos: &Point,
116 dims: Vector3<usize>,
117 origin: Point,
118 cell_size: f64,
119 ) -> Option<usize> {
120 if pos.x < origin.x || pos.y < origin.y || pos.z < origin.z {
121 return None;
122 }
123
124 let offset = pos - origin;
125 let x = (offset.x / cell_size).floor() as usize;
126 let y = (offset.y / cell_size).floor() as usize;
127 let z = (offset.z / cell_size).floor() as usize;
128
129 if x >= dims.x || y >= dims.y || z >= dims.z {
130 return None;
131 }
132
133 Some(x + y * dims.x + z * dims.x * dims.y)
134 }
135
136 pub fn neighbors<'a>(&'a self, center: &Point, radius: f64) -> GridNeighborhood<'a, T> {
146 if self.items.is_empty() {
147 return GridNeighborhood {
148 grid: self,
149 min_x: 0,
150 max_x: 0,
151 min_y: 0,
152 max_y: 0,
153 max_z: 0,
154 curr_x: 0,
155 curr_y: 0,
156 curr_z: 1,
157 curr_item_idx: SENTINEL,
158 center: *center,
159 radius_sq: radius * radius,
160 };
161 }
162
163 let min_idx = self.get_grid_coords(&(center - Vector3::new(radius, radius, radius)));
164 let max_idx = self.get_grid_coords(&(center + Vector3::new(radius, radius, radius)));
165
166 let (min_x, min_y, min_z) = min_idx;
167 let (max_x, max_y, max_z) = max_idx;
168
169 GridNeighborhood {
170 grid: self,
171 min_x,
172 max_x,
173 min_y,
174 max_y,
175 max_z,
176 curr_x: min_x,
177 curr_y: min_y,
178 curr_z: min_z,
179 curr_item_idx: SENTINEL,
180 center: *center,
181 radius_sq: radius * radius,
182 }
183 }
184
185 fn get_grid_coords(&self, pos: &Point) -> (usize, usize, usize) {
187 let offset = pos - self.origin;
188 let x = (offset.x / self.cell_size).floor() as isize;
189 let y = (offset.y / self.cell_size).floor() as isize;
190 let z = (offset.z / self.cell_size).floor() as isize;
191
192 (
193 x.clamp(0, (self.dims.x as isize) - 1) as usize,
194 y.clamp(0, (self.dims.y as isize) - 1) as usize,
195 z.clamp(0, (self.dims.z as isize) - 1) as usize,
196 )
197 }
198
199 pub fn has_neighbor<F>(&self, point: &Point, radius: f64, mut predicate: F) -> bool
209 where
210 F: FnMut(&T) -> bool,
211 {
212 for item in self.neighbors(point, radius) {
213 if predicate(item) {
214 return true;
215 }
216 }
217 false
218 }
219}
220
221pub struct GridNeighborhood<'a, T> {
226 grid: &'a Grid<T>,
227 min_x: usize,
228 max_x: usize,
229 min_y: usize,
230 max_y: usize,
231 max_z: usize,
232 curr_x: usize,
233 curr_y: usize,
234 curr_z: usize,
235 curr_item_idx: u32,
236 center: Point,
237 radius_sq: f64,
238}
239
240impl<'a, T> GridNeighborhood<'a, T> {
241 pub fn exact(self) -> impl Iterator<Item = (Point, &'a T)> + 'a {
246 ExactGridNeighborhood { inner: self }
247 }
248}
249
/// Iterator returned by [`GridNeighborhood::exact`].
///
/// It walks the same cells as the wrapped coarse iterator but yields only items whose
/// squared distance to the query centre is at most the squared radius, each paired
/// with its stored position.
pub struct ExactGridNeighborhood<'a, T> {
    /// Coarse iterator whose cell cursor and query parameters are advanced in place.
    inner: GridNeighborhood<'a, T>,
}
258
259impl<'a, T> Iterator for ExactGridNeighborhood<'a, T> {
260 type Item = (Point, &'a T);
261
262 fn next(&mut self) -> Option<Self::Item> {
263 loop {
264 if self.inner.curr_item_idx != SENTINEL {
265 let (pos, item) = &self.inner.grid.items[self.inner.curr_item_idx as usize];
266 self.inner.curr_item_idx = self.inner.grid.next[self.inner.curr_item_idx as usize];
267
268 if nalgebra::distance_squared(pos, &self.inner.center) <= self.inner.radius_sq {
269 return Some((*pos, item));
270 }
271 continue;
272 }
273
274 if self.inner.curr_x > self.inner.max_x {
275 self.inner.curr_x = self.inner.min_x;
276 self.inner.curr_y += 1;
277 }
278 if self.inner.curr_y > self.inner.max_y {
279 self.inner.curr_y = self.inner.min_y;
280 self.inner.curr_z += 1;
281 }
282 if self.inner.curr_z > self.inner.max_z {
283 return None;
284 }
285
286 let cell_idx = self.inner.curr_x
287 + self.inner.curr_y * self.inner.grid.dims.x
288 + self.inner.curr_z * self.inner.grid.dims.x * self.inner.grid.dims.y;
289
290 self.inner.curr_x += 1;
291
292 if cell_idx < self.inner.grid.head.len() {
293 self.inner.curr_item_idx = self.inner.grid.head[cell_idx];
294 }
295 }
296 }
297}
298
299impl<'a, T> Iterator for GridNeighborhood<'a, T> {
300 type Item = &'a T;
301
302 fn next(&mut self) -> Option<Self::Item> {
303 loop {
304 if self.curr_item_idx != SENTINEL {
305 let (_, item) = &self.grid.items[self.curr_item_idx as usize];
306 self.curr_item_idx = self.grid.next[self.curr_item_idx as usize];
307 return Some(item);
308 }
309
310 if self.curr_x > self.max_x {
311 self.curr_x = self.min_x;
312 self.curr_y += 1;
313 }
314 if self.curr_y > self.max_y {
315 self.curr_y = self.min_y;
316 self.curr_z += 1;
317 }
318 if self.curr_z > self.max_z {
319 return None;
320 }
321
322 let cell_idx = self.curr_x
323 + self.curr_y * self.grid.dims.x
324 + self.curr_z * self.grid.dims.x * self.grid.dims.y;
325
326 self.curr_x += 1;
327
328 if cell_idx < self.grid.head.len() {
329 self.curr_item_idx = self.grid.head[cell_idx];
330 }
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::types::Point;

    #[test]
    fn grid_bins_points_correctly() {
        let points = vec![
            (Point::new(0.5, 0.5, 0.5), 1),
            (Point::new(1.5, 0.5, 0.5), 2),
            (Point::new(0.5, 1.5, 0.5), 3),
        ];

        let grid = Grid::new(points, 1.0);

        assert_eq!(grid.dims, Vector3::new(2, 2, 1));

        let center = Point::new(0.5, 0.5, 0.5);
        let neighbors: Vec<_> = grid.neighbors(&center, 0.1).collect();
        assert!(neighbors.contains(&&1));
        assert!(!neighbors.contains(&&2));
    }

    #[test]
    fn grid_neighbors_returns_nearby_items() {
        let points = vec![
            (Point::new(0.0, 0.0, 0.0), "A"),
            (Point::new(10.0, 0.0, 0.0), "B"),
        ];
        let grid = Grid::new(points, 2.0);

        let center = Point::new(0.1, 0.1, 0.1);
        let neighbors: Vec<_> = grid.neighbors(&center, 1.0).collect();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(*neighbors[0], "A");
    }

    #[test]
    fn grid_handles_empty_input() {
        let points: Vec<(Point, i32)> = vec![];
        let grid = Grid::new(points, 1.0);
        assert_eq!(grid.items.len(), 0);
        assert_eq!(grid.neighbors(&Point::origin(), 1.0).count(), 0);
    }

    #[test]
    fn grid_handles_dense_packing() {
        let mut points = Vec::new();
        for i in 0..100 {
            points.push((Point::new(0.1, 0.1, 0.1), i));
        }
        let grid = Grid::new(points, 1.0);

        let center = Point::new(0.1, 0.1, 0.1);
        let count = grid.neighbors(&center, 0.5).count();
        assert_eq!(count, 100);
    }

    #[test]
    fn grid_handles_boundary_conditions() {
        let points = vec![
            (Point::new(0.0, 0.0, 0.0), 1),
            (Point::new(10.0, 10.0, 10.0), 2),
        ];
        let grid = Grid::new(points, 1.0);

        let center = Point::new(0.0, 0.0, 0.0);
        assert!(grid.has_neighbor(&center, 0.1, |&i| i == 1));

        let center = Point::new(10.0, 10.0, 10.0);
        assert!(grid.has_neighbor(&center, 0.1, |&i| i == 2));
    }

    #[test]
    fn grid_exact_filtering_works() {
        let points = vec![
            (Point::new(0.0, 0.0, 0.0), "Center"),
            (Point::new(0.9, 0.0, 0.0), "Inside"),
            (Point::new(1.1, 0.0, 0.0), "Outside"),
        ];
        let grid = Grid::new(points, 2.0);

        let center = Point::new(0.0, 0.0, 0.0);
        let radius = 1.0;

        let coarse_count = grid.neighbors(&center, radius).count();
        assert_eq!(coarse_count, 3);

        let exact_neighbors: Vec<_> = grid.neighbors(&center, radius).exact().collect();
        assert_eq!(exact_neighbors.len(), 2);

        let contains_item = |name: &str| exact_neighbors.iter().any(|(_, item)| **item == name);

        assert!(contains_item("Center"));
        assert!(contains_item("Inside"));
        assert!(!contains_item("Outside"));
    }

    #[test]
    fn grid_exact_filtering_handles_empty_grid() {
        let points: Vec<(Point, i32)> = vec![];
        let grid = Grid::new(points, 1.0);

        let count = grid.neighbors(&Point::origin(), 1.0).exact().count();
        assert_eq!(count, 0);
    }

    #[test]
    fn grid_exact_filtering_handles_boundary_points() {
        let points = vec![(Point::new(1.0, 0.0, 0.0), "OnBoundary")];
        let grid = Grid::new(points, 2.0);

        let center = Point::new(0.0, 0.0, 0.0);
        let count = grid.neighbors(&center, 1.0).exact().count();
        assert_eq!(count, 1);

        let count = grid.neighbors(&center, 0.99).exact().count();
        assert_eq!(count, 0);
    }

    #[test]
    fn grid_exact_filtering_ignores_far_away_queries() {
        let points = vec![
            (Point::new(0.0, 0.0, 0.0), 1),
            (Point::new(10.0, 10.0, 10.0), 2),
        ];
        let grid = Grid::new(points, 1.0);

        let center = Point::new(100.0, 100.0, 100.0);
        assert_eq!(grid.neighbors(&center, 1.0).exact().count(), 0);
    }

    #[test]
    fn grid_infinite_radius_visits_every_item() {
        let points: Vec<_> = (0..10)
            .map(|i| (Point::new(i as f64, (i * 2) as f64, (i % 3) as f64), i))
            .collect();
        let grid = Grid::new(points, 1.5);

        let center = Point::new(4.0, 4.0, 1.0);
        let mut found: Vec<i32> = grid.neighbors(&center, f64::INFINITY).copied().collect();
        found.sort_unstable();
        assert_eq!(found, (0..10).collect::<Vec<_>>());
    }
}