bio_forge/model/
grid.rs

1//! Spatial indexing primitives for accelerating geometric queries.
2//!
3//! This module provides a [`Grid`] structure that partitions 3D space into uniform cells,
4//! enabling **O(1)** average-case lookups for neighbor searches, collision detection, and
5//! range queries.
6
7use super::types::Point;
8use nalgebra::Vector3;
9
/// Marker stored in a `head` or `next` slot that does not point at any item: either the
/// cell is empty or the end of the cell's linked list has been reached.
const SENTINEL: u32 = u32::MAX;
12
13/// A uniform spatial grid that bins items into cubic cells.
14///
15/// The grid is defined by a bounding box and a cell size. Items are mapped to cells
16/// based on their coordinates. This structure is optimized for "fixed-radius" queries,
17/// where the search radius is comparable to the cell size.
18///
19/// # Performance
20///
21/// - Construction: **O(N)** where N is the number of items.
22/// - Neighbor queries: **O(1)** average-case per query, assuming uniform distribution.
23#[derive(Debug, Clone)]
24pub struct Grid<T> {
25    /// Side length of each cubic cell.
26    cell_size: f64,
27    /// Minimum coordinate of the grid's bounding box.
28    origin: Point,
29    /// Number of cells along each dimension (x, y, z).
30    dims: Vector3<usize>,
31    /// Index of the first item in each cell. Size = num_cells.
32    head: Vec<u32>,
33    /// Index of the next item in the linked list. Size = num_items.
34    next: Vec<u32>,
35    /// Stored items with their positions. Size = num_items.
36    items: Vec<(Point, T)>,
37}
38
39impl<T> Grid<T> {
40    /// Creates a new grid enclosing the provided points.
41    ///
42    /// The grid dimensions are automatically calculated to encompass all points with
43    /// a small padding.
44    ///
45    /// # Arguments
46    ///
47    /// * `items` - Iterator yielding `(position, item)` pairs.
48    /// * `cell_size` - The side length of each spatial bin.
49    ///
50    /// # Panics
51    ///
52    /// Panics if `cell_size` is non-positive.
53    pub fn new(items: impl IntoIterator<Item = (Point, T)>, cell_size: f64) -> Self {
54        assert!(cell_size > 0.0, "Cell size must be positive");
55
56        let input_items: Vec<_> = items.into_iter().collect();
57        let num_items = input_items.len();
58
59        if num_items == 0 {
60            return Self {
61                cell_size,
62                origin: Point::origin(),
63                dims: Vector3::zeros(),
64                head: Vec::new(),
65                next: Vec::new(),
66                items: Vec::new(),
67            };
68        }
69
70        let mut min = Point::new(f64::MAX, f64::MAX, f64::MAX);
71        let mut max = Point::new(f64::MIN, f64::MIN, f64::MIN);
72
73        for (pos, _) in &input_items {
74            min = min.inf(pos);
75            max = max.sup(pos);
76        }
77
78        let epsilon = 1e-6;
79        max += Vector3::new(epsilon, epsilon, epsilon);
80
81        let extent = max - min;
82        let dims = Vector3::new(
83            (extent.x / cell_size).ceil() as usize,
84            (extent.y / cell_size).ceil() as usize,
85            (extent.z / cell_size).ceil() as usize,
86        );
87
88        let total_cells = dims.x * dims.y * dims.z;
89
90        let mut head = vec![SENTINEL; total_cells];
91        let mut next = vec![SENTINEL; num_items];
92        let mut stored_items = Vec::with_capacity(num_items);
93
94        for (i, (pos, item)) in input_items.into_iter().enumerate() {
95            stored_items.push((pos, item));
96
97            if let Some(cell_idx) = Self::get_cell_index_static(&pos, dims, min, cell_size) {
98                next[i] = head[cell_idx];
99                head[cell_idx] = i as u32;
100            }
101        }
102
103        Self {
104            cell_size,
105            origin: min,
106            dims,
107            head,
108            next,
109            items: stored_items,
110        }
111    }
112
113    /// Static helper to compute cell index without `self`.
114    fn get_cell_index_static(
115        pos: &Point,
116        dims: Vector3<usize>,
117        origin: Point,
118        cell_size: f64,
119    ) -> Option<usize> {
120        if pos.x < origin.x || pos.y < origin.y || pos.z < origin.z {
121            return None;
122        }
123
124        let offset = pos - origin;
125        let x = (offset.x / cell_size).floor() as usize;
126        let y = (offset.y / cell_size).floor() as usize;
127        let z = (offset.z / cell_size).floor() as usize;
128
129        if x >= dims.x || y >= dims.y || z >= dims.z {
130            return None;
131        }
132
133        Some(x + y * dims.x + z * dims.x * dims.y)
134    }
135
136    /// Iterates over all items in cells overlapping with the query sphere.
137    ///
138    /// The returned iterator yields items from candidate cells. To filter strictly by
139    /// Euclidean distance, use the `.exact()` method on the returned iterator.
140    ///
141    /// # Arguments
142    ///
143    /// * `center` - Center of the search sphere.
144    /// * `radius` - Radius of the search sphere.
145    pub fn neighbors<'a>(&'a self, center: &Point, radius: f64) -> GridNeighborhood<'a, T> {
146        if self.items.is_empty() {
147            return GridNeighborhood {
148                grid: self,
149                min_x: 0,
150                max_x: 0,
151                min_y: 0,
152                max_y: 0,
153                max_z: 0,
154                curr_x: 0,
155                curr_y: 0,
156                curr_z: 1,
157                curr_item_idx: SENTINEL,
158                center: *center,
159                radius_sq: radius * radius,
160            };
161        }
162
163        let min_idx = self.get_grid_coords(&(center - Vector3::new(radius, radius, radius)));
164        let max_idx = self.get_grid_coords(&(center + Vector3::new(radius, radius, radius)));
165
166        let (min_x, min_y, min_z) = min_idx;
167        let (max_x, max_y, max_z) = max_idx;
168
169        GridNeighborhood {
170            grid: self,
171            min_x,
172            max_x,
173            min_y,
174            max_y,
175            max_z,
176            curr_x: min_x,
177            curr_y: min_y,
178            curr_z: min_z,
179            curr_item_idx: SENTINEL,
180            center: *center,
181            radius_sq: radius * radius,
182        }
183    }
184
185    /// Helper to get clamped grid coordinates (x, y, z).
186    fn get_grid_coords(&self, pos: &Point) -> (usize, usize, usize) {
187        let offset = pos - self.origin;
188        let x = (offset.x / self.cell_size).floor() as isize;
189        let y = (offset.y / self.cell_size).floor() as isize;
190        let z = (offset.z / self.cell_size).floor() as isize;
191
192        (
193            x.clamp(0, (self.dims.x as isize) - 1) as usize,
194            y.clamp(0, (self.dims.y as isize) - 1) as usize,
195            z.clamp(0, (self.dims.z as isize) - 1) as usize,
196        )
197    }
198
199    /// Checks if any item in the grid is within `radius` of `point`.
200    ///
201    /// This is optimized to return early.
202    ///
203    /// # Arguments
204    ///
205    /// * `point` - The query point.
206    /// * `radius` - The cutoff distance.
207    /// * `predicate` - A closure to filter items (e.g., check exact distance).
208    pub fn has_neighbor<F>(&self, point: &Point, radius: f64, mut predicate: F) -> bool
209    where
210        F: FnMut(&T) -> bool,
211    {
212        for item in self.neighbors(point, radius) {
213            if predicate(item) {
214                return true;
215            }
216        }
217        false
218    }
219}
220
221/// Iterator for traversing grid cells and their linked lists.
222///
223/// This iterator yields all items in the cells that overlap with the query sphere.
224/// Use [`GridNeighborhood::exact`] to filter items strictly within the radius.
225pub struct GridNeighborhood<'a, T> {
226    grid: &'a Grid<T>,
227    min_x: usize,
228    max_x: usize,
229    min_y: usize,
230    max_y: usize,
231    max_z: usize,
232    curr_x: usize,
233    curr_y: usize,
234    curr_z: usize,
235    curr_item_idx: u32,
236    center: Point,
237    radius_sq: f64,
238}
239
240impl<'a, T> GridNeighborhood<'a, T> {
241    /// Returns an iterator that yields only items strictly within the search radius.
242    ///
243    /// This method uses the internally stored positions to perform the distance check,
244    /// avoiding the need for external lookups.
245    pub fn exact(self) -> impl Iterator<Item = (Point, &'a T)> + 'a {
246        ExactGridNeighborhood { inner: self }
247    }
248}
249
250/// Iterator that yields items strictly within the Euclidean radius.
251///
252/// This iterator filters candidates from [`GridNeighborhood`] using the stored positions
253/// and the query center/radius, returning only items whose distance to the center is
254/// less than or equal to the specified radius.
255pub struct ExactGridNeighborhood<'a, T> {
256    inner: GridNeighborhood<'a, T>,
257}
258
259impl<'a, T> Iterator for ExactGridNeighborhood<'a, T> {
260    type Item = (Point, &'a T);
261
262    fn next(&mut self) -> Option<Self::Item> {
263        loop {
264            if self.inner.curr_item_idx != SENTINEL {
265                let (pos, item) = &self.inner.grid.items[self.inner.curr_item_idx as usize];
266                self.inner.curr_item_idx = self.inner.grid.next[self.inner.curr_item_idx as usize];
267
268                if nalgebra::distance_squared(pos, &self.inner.center) <= self.inner.radius_sq {
269                    return Some((*pos, item));
270                }
271                continue;
272            }
273
274            if self.inner.curr_x > self.inner.max_x {
275                self.inner.curr_x = self.inner.min_x;
276                self.inner.curr_y += 1;
277            }
278            if self.inner.curr_y > self.inner.max_y {
279                self.inner.curr_y = self.inner.min_y;
280                self.inner.curr_z += 1;
281            }
282            if self.inner.curr_z > self.inner.max_z {
283                return None;
284            }
285
286            let cell_idx = self.inner.curr_x
287                + self.inner.curr_y * self.inner.grid.dims.x
288                + self.inner.curr_z * self.inner.grid.dims.x * self.inner.grid.dims.y;
289
290            self.inner.curr_x += 1;
291
292            if cell_idx < self.inner.grid.head.len() {
293                self.inner.curr_item_idx = self.inner.grid.head[cell_idx];
294            }
295        }
296    }
297}
298
299impl<'a, T> Iterator for GridNeighborhood<'a, T> {
300    type Item = &'a T;
301
302    fn next(&mut self) -> Option<Self::Item> {
303        loop {
304            if self.curr_item_idx != SENTINEL {
305                let (_, item) = &self.grid.items[self.curr_item_idx as usize];
306                self.curr_item_idx = self.grid.next[self.curr_item_idx as usize];
307                return Some(item);
308            }
309
310            if self.curr_x > self.max_x {
311                self.curr_x = self.min_x;
312                self.curr_y += 1;
313            }
314            if self.curr_y > self.max_y {
315                self.curr_y = self.min_y;
316                self.curr_z += 1;
317            }
318            if self.curr_z > self.max_z {
319                return None;
320            }
321
322            let cell_idx = self.curr_x
323                + self.curr_y * self.grid.dims.x
324                + self.curr_z * self.grid.dims.x * self.grid.dims.y;
325
326            self.curr_x += 1;
327
328            if cell_idx < self.grid.head.len() {
329                self.curr_item_idx = self.grid.head[cell_idx];
330            }
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::types::Point;

    /// Points one cell apart in x and y give a 2x2x1 grid, and a tight query only sees
    /// the item in the query's own cell.
    #[test]
    fn grid_bins_points_correctly() {
        let grid = Grid::new(
            vec![
                (Point::new(0.5, 0.5, 0.5), 1),
                (Point::new(1.5, 0.5, 0.5), 2),
                (Point::new(0.5, 1.5, 0.5), 3),
            ],
            1.0,
        );

        assert_eq!(grid.dims, Vector3::new(2, 2, 1));

        let query = Point::new(0.5, 0.5, 0.5);
        let found: Vec<_> = grid.neighbors(&query, 0.1).collect();
        assert!(found.contains(&&1));
        assert!(!found.contains(&&2));
    }

    /// A query near one item must not report an item several cells away.
    #[test]
    fn grid_neighbors_returns_nearby_items() {
        let grid = Grid::new(
            vec![
                (Point::new(0.0, 0.0, 0.0), "A"),
                (Point::new(10.0, 0.0, 0.0), "B"),
            ],
            2.0,
        );

        let query = Point::new(0.1, 0.1, 0.1);
        let found: Vec<_> = grid.neighbors(&query, 1.0).collect();
        assert_eq!(found.len(), 1);
        assert_eq!(*found[0], "A");
    }

    /// An empty grid stores nothing and answers every query with nothing.
    #[test]
    fn grid_handles_empty_input() {
        let grid = Grid::new(Vec::<(Point, i32)>::new(), 1.0);
        assert_eq!(grid.items.len(), 0);
        assert_eq!(grid.neighbors(&Point::origin(), 1.0).count(), 0);
    }

    /// Many items at the same position all end up in, and are returned from, one cell.
    #[test]
    fn grid_handles_dense_packing() {
        let stacked: Vec<_> = (0..100)
            .map(|i| (Point::new(0.1, 0.1, 0.1), i))
            .collect();
        let grid = Grid::new(stacked, 1.0);

        let query = Point::new(0.1, 0.1, 0.1);
        assert_eq!(grid.neighbors(&query, 0.5).count(), 100);
    }

    /// Items sitting on the lower and upper corners of the bounding box are findable.
    #[test]
    fn grid_handles_boundary_conditions() {
        let grid = Grid::new(
            vec![
                (Point::new(0.0, 0.0, 0.0), 1),
                (Point::new(10.0, 10.0, 10.0), 2),
            ],
            1.0,
        );

        let lower_corner = Point::new(0.0, 0.0, 0.0);
        assert!(grid.has_neighbor(&lower_corner, 0.1, |&id| id == 1));

        let upper_corner = Point::new(10.0, 10.0, 10.0);
        assert!(grid.has_neighbor(&upper_corner, 0.1, |&id| id == 2));
    }

    /// The coarse iterator returns the whole cell; `.exact()` drops items beyond the radius.
    #[test]
    fn grid_exact_filtering_works() {
        let grid = Grid::new(
            vec![
                (Point::new(0.0, 0.0, 0.0), "Center"),
                (Point::new(0.9, 0.0, 0.0), "Inside"),
                (Point::new(1.1, 0.0, 0.0), "Outside"),
            ],
            2.0,
        );

        let query = Point::new(0.0, 0.0, 0.0);
        let radius = 1.0;

        assert_eq!(grid.neighbors(&query, radius).count(), 3);

        let exact: Vec<_> = grid.neighbors(&query, radius).exact().collect();
        assert_eq!(exact.len(), 2);

        let names: Vec<&str> = exact.iter().map(|(_, name)| **name).collect();
        assert!(names.contains(&"Center"));
        assert!(names.contains(&"Inside"));
        assert!(!names.contains(&"Outside"));
    }

    /// `.exact()` on an empty grid yields nothing.
    #[test]
    fn grid_exact_filtering_handles_empty_grid() {
        let grid = Grid::new(Vec::<(Point, i32)>::new(), 1.0);

        let hits = grid.neighbors(&Point::origin(), 1.0).exact().count();
        assert_eq!(hits, 0);
    }

    /// An item exactly at the radius is included; a slightly smaller radius excludes it.
    #[test]
    fn grid_exact_filtering_handles_boundary_points() {
        let grid = Grid::new(vec![(Point::new(1.0, 0.0, 0.0), "OnBoundary")], 2.0);
        let query = Point::new(0.0, 0.0, 0.0);

        let at_radius = grid.neighbors(&query, 1.0).exact().count();
        assert_eq!(at_radius, 1);

        let just_short = grid.neighbors(&query, 0.99).exact().count();
        assert_eq!(just_short, 0);
    }
}