Skip to main content

bonsai/backends/
grid.rs

1//! Grid spatial index backend.
2//!
3//! Implements the [`SpatialBackend`] trait using a uniform spatial hash grid
4//! with D-dimensional cell coordinates addressed by `[i32; D]`.
5
6use std::collections::HashMap;
7
8use crate::backends::SpatialBackend;
9use crate::types::{BBox, BackendKind, CoordType, EntryId, Point};
10
/// Integer lattice index of a single grid cell: one `i32` per axis.
///
/// Serves as the key of the occupied-cell `HashMap`, hence `Eq + Hash`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct GridCoord<const D: usize>([i32; D]);
14
15/// Uniform spatial hash grid index.
16///
17/// Points are bucketed into axis-aligned cells of size `cell_size[d]` along
18/// each dimension. The cell coordinate for a point `p` along axis `d` is
19/// `floor((p[d] - origin[d]) / cell_size[d])`.
20///
21/// # Type Parameters
22/// - `T`: payload type
23/// - `C`: coordinate scalar type (must implement [`CoordType`])
24/// - `D`: number of spatial dimensions
25pub struct GridIndex<T, C, const D: usize> {
26    cells: HashMap<GridCoord<D>, Vec<(Point<C, D>, T, EntryId)>>,
27    cell_size: [C; D],
28    origin: Point<C, D>,
29    len: usize,
30    next_id: u64,
31    id_to_cell: HashMap<u64, GridCoord<D>>,
32}
33
34fn cell_coord<C: CoordType, const D: usize>(
35    point: &Point<C, D>,
36    origin: &Point<C, D>,
37    cell_size: &[C; D],
38) -> GridCoord<D> {
39    let mut coord = [0i32; D];
40    for d in 0..D {
41        let p: f64 = point.coords()[d].into();
42        let o: f64 = origin.coords()[d].into();
43        let s: f64 = cell_size[d].into();
44        coord[d] = if s == 0.0 {
45            0
46        } else {
47            ((p - o) / s).floor() as i32
48        };
49    }
50    GridCoord(coord)
51}
52
53fn cell_range<C: CoordType, const D: usize>(
54    bbox: &BBox<C, D>,
55    origin: &Point<C, D>,
56    cell_size: &[C; D],
57) -> ([i32; D], [i32; D]) {
58    let mut min_coord = [0i32; D];
59    let mut max_coord = [0i32; D];
60    for d in 0..D {
61        let lo: f64 = bbox.min.coords()[d].into();
62        let hi: f64 = bbox.max.coords()[d].into();
63        let o: f64 = origin.coords()[d].into();
64        let s: f64 = cell_size[d].into();
65        if s == 0.0 {
66            min_coord[d] = 0;
67            max_coord[d] = 0;
68        } else {
69            min_coord[d] = ((lo - o) / s).floor() as i32;
70            max_coord[d] = ((hi - o) / s).floor() as i32;
71        }
72    }
73    (min_coord, max_coord)
74}
75
76/// Iterate over every `[i32; D]` coordinate in `[min, max]` (inclusive) and
77/// call `f` for each. The last axis varies fastest.
78fn for_each_cell_in_range<const D: usize, F: FnMut(GridCoord<D>)>(
79    min: &[i32; D],
80    max: &[i32; D],
81    f: &mut F,
82) {
83    let mut current = *min;
84    loop {
85        f(GridCoord(current));
86
87        let mut carry = true;
88        for d in (0..D).rev() {
89            if carry {
90                if current[d] < max[d] {
91                    current[d] += 1;
92                    carry = false;
93                } else {
94                    current[d] = min[d];
95                }
96            }
97        }
98        if carry {
99            break;
100        }
101    }
102}
103
104impl<T, C: CoordType, const D: usize> GridIndex<T, C, D> {
105    /// Create an empty grid with the given cell size and origin.
106    pub fn new(cell_size: [C; D], origin: Point<C, D>) -> Self {
107        Self {
108            cells: HashMap::new(),
109            cell_size,
110            origin,
111            len: 0,
112            next_id: 0,
113            id_to_cell: HashMap::new(),
114        }
115    }
116
117    fn alloc_id(&mut self) -> EntryId {
118        let id = EntryId(self.next_id);
119        self.next_id += 1;
120        id
121    }
122
123    /// Choose a cell size targeting approximately one point per cell for `n`
124    /// uniformly distributed points over `bbox`.
125    ///
126    /// Formula: `cell_size[d] = bbox_span[d] / n^(1/D)`
127    fn default_cell_size(bbox: &BBox<C, D>, n: usize) -> [C; D] {
128        let n_f = (n.max(1) as f64).powf(1.0 / D as f64);
129        let mut cs = [C::zero(); D];
130        for (d, c) in cs.iter_mut().enumerate().take(D) {
131            let span: f64 = (bbox.max.coords()[d] - bbox.min.coords()[d]).into();
132            let s = (span / n_f).max(1.0);
133            *c = C::from(s as f32);
134        }
135        cs
136    }
137
138    /// Return the number of occupied cells.
139    pub fn cell_count(&self) -> usize {
140        self.cells.len()
141    }
142
143    pub fn insert_entry(&mut self, point: Point<C, D>, payload: T) -> EntryId {
144        let id = self.alloc_id();
145        let coord = cell_coord(&point, &self.origin, &self.cell_size);
146        self.id_to_cell.insert(id.0, coord);
147        self.cells
148            .entry(coord)
149            .or_default()
150            .push((point, payload, id));
151        self.len += 1;
152        id
153    }
154
155    pub fn remove_entry(&mut self, id: EntryId) -> Option<T> {
156        let coord = self.id_to_cell.remove(&id.0)?;
157        let cell = self.cells.get_mut(&coord)?;
158        let pos = cell.iter().position(|(_, _, eid)| *eid == id)?;
159        let (_, payload, _) = cell.swap_remove(pos);
160        if cell.is_empty() {
161            self.cells.remove(&coord);
162        }
163        self.len -= 1;
164        Some(payload)
165    }
166
167    pub fn range_query_impl<'a>(&'a self, bbox: &BBox<C, D>) -> Vec<(EntryId, &'a T)> {
168        let (min_coord, max_coord) = cell_range(bbox, &self.origin, &self.cell_size);
169        let mut out = Vec::new();
170        for_each_cell_in_range(&min_coord, &max_coord, &mut |coord| {
171            if let Some(cell) = self.cells.get(&coord) {
172                for (point, payload, id) in cell {
173                    if bbox.contains_point(point) {
174                        out.push((*id, payload));
175                    }
176                }
177            }
178        });
179        out
180    }
181
182    pub fn knn_query_impl<'a>(
183        &'a self,
184        point: &Point<C, D>,
185        k: usize,
186    ) -> Vec<(f64, EntryId, &'a T)> {
187        if k == 0 {
188            return Vec::new();
189        }
190        let mut all: Vec<(f64, EntryId, &'a T)> = self
191            .cells
192            .values()
193            .flat_map(|cell| cell.iter())
194            .map(|(p, payload, id)| (point_dist_sq(p, point).sqrt(), *id, payload))
195            .collect();
196        all.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
197        all.truncate(k);
198        all
199    }
200
201    fn collect_all(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
202        self.cells
203            .values()
204            .flat_map(|cell| cell.iter())
205            .map(|(p, payload, id)| (*p, *id, payload))
206            .collect()
207    }
208}
209
210fn point_dist_sq<C: CoordType, const D: usize>(a: &Point<C, D>, b: &Point<C, D>) -> f64 {
211    let mut sum = 0.0_f64;
212    for d in 0..D {
213        let da: f64 = a.coords()[d].into();
214        let db: f64 = b.coords()[d].into();
215        let diff = da - db;
216        sum += diff * diff;
217    }
218    sum
219}
220
221impl<T, C: CoordType, const D: usize> Default for GridIndex<T, C, D> {
222    fn default() -> Self {
223        Self::new([C::from(1.0_f32); D], Point::new([C::zero(); D]))
224    }
225}
226
227impl<T: Send + Sync + 'static, C: CoordType, const D: usize> SpatialBackend<T, C, D>
228    for GridIndex<T, C, D>
229{
230    fn insert(&mut self, point: Point<C, D>, payload: T) -> EntryId {
231        self.insert_entry(point, payload)
232    }
233
234    fn remove(&mut self, id: EntryId) -> Option<T> {
235        self.remove_entry(id)
236    }
237
238    fn range_query(&self, bbox: &BBox<C, D>) -> Vec<(EntryId, &T)> {
239        self.range_query_impl(bbox)
240    }
241
242    fn knn_query(&self, point: &Point<C, D>, k: usize) -> Vec<(f64, EntryId, &T)> {
243        self.knn_query_impl(point, k)
244    }
245
246    fn spatial_join(&self, other: &dyn SpatialBackend<T, C, D>) -> Vec<(EntryId, EntryId)> {
247        let self_entries = self.collect_all();
248        let other_entries = other.all_entries();
249        let mut pairs = Vec::new();
250        for (pa, id_a, _) in &self_entries {
251            let bbox_a = BBox::new(*pa, *pa);
252            for (pb, id_b, _) in &other_entries {
253                if bbox_a.intersects(&BBox::new(*pb, *pb)) {
254                    pairs.push((*id_a, *id_b));
255                }
256            }
257        }
258        pairs
259    }
260
261    fn bulk_load(entries: Vec<(Point<C, D>, T)>) -> Self {
262        if entries.is_empty() {
263            return Self::default();
264        }
265        let mut min_c = *entries[0].0.coords();
266        let mut max_c = *entries[0].0.coords();
267        for (p, _) in &entries {
268            for d in 0..D {
269                let v = p.coords()[d];
270                if v < min_c[d] {
271                    min_c[d] = v;
272                }
273                if v > max_c[d] {
274                    max_c[d] = v;
275                }
276            }
277        }
278        let bbox = BBox::new(Point::new(min_c), Point::new(max_c));
279        let cell_size = Self::default_cell_size(&bbox, entries.len());
280        let origin = Point::new(min_c);
281        let mut grid = Self::new(cell_size, origin);
282        for (point, payload) in entries {
283            grid.insert_entry(point, payload);
284        }
285        grid
286    }
287
288    fn len(&self) -> usize {
289        self.len
290    }
291
292    fn kind(&self) -> BackendKind {
293        BackendKind::Grid
294    }
295
296    fn all_entries(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
297        self.collect_all()
298    }
299}
300
/// Unit and property tests for `GridIndex`.
///
/// Range-query results are checked against a brute-force oracle (`brute_range`).
/// Random data comes from a small deterministic LCG, so failures are reproducible.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Minimal deterministic 64-bit linear congruential generator
    /// (Knuth's MMIX constants), so tests need no RNG crate.
    struct Lcg(u64);
    impl Lcg {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
        /// Advance the state and return a value in `[0, 1)` built from the
        /// top 53 bits of the new state.
        fn next_f64(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
    }

    /// Oracle: ids of every point in `pts` that `bbox` contains, found by
    /// linear scan and sorted by raw id for order-independent comparison.
    fn brute_range<C: CoordType, const D: usize>(
        pts: &[(Point<C, D>, EntryId)],
        bbox: &BBox<C, D>,
    ) -> Vec<EntryId> {
        let mut ids: Vec<EntryId> = pts
            .iter()
            .filter(|(p, _)| bbox.contains_point(p))
            .map(|(_, id)| *id)
            .collect();
        ids.sort_by_key(|id| id.0);
        ids
    }

    #[test]
    fn insert_and_len() {
        let mut grid = GridIndex::<u32, f64, 2>::default();
        assert_eq!(grid.len(), 0);
        grid.insert(Point::new([0.5, 0.5]), 1u32);
        assert_eq!(grid.len(), 1);
        grid.insert(Point::new([1.5, 1.5]), 2u32);
        assert_eq!(grid.len(), 2);
    }

    #[test]
    fn range_query_basic() {
        let mut grid = GridIndex::<u32, f64, 2>::default();
        let id1 = grid.insert(Point::new([0.5, 0.5]), 1u32);
        let id2 = grid.insert(Point::new([1.5, 1.5]), 2u32);
        let _id3 = grid.insert(Point::new([5.0, 5.0]), 3u32);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([2.0, 2.0]));
        let mut got: Vec<EntryId> = grid
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        // Query order is unspecified; sort before comparing.
        got.sort_by_key(|id| id.0);
        assert_eq!(got, vec![id1, id2]);
    }

    #[test]
    fn remove_works() {
        let mut grid = GridIndex::<u32, f64, 2>::default();
        let id1 = grid.insert(Point::new([1.0, 1.0]), 10u32);
        let id2 = grid.insert(Point::new([2.0, 2.0]), 20u32);
        assert_eq!(grid.len(), 2);
        assert_eq!(grid.remove(id1), Some(10u32));
        assert_eq!(grid.len(), 1);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([3.0, 3.0]));
        let ids: Vec<EntryId> = grid
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert!(!ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[test]
    fn kind_is_grid() {
        assert_eq!(
            GridIndex::<u32, f64, 2>::default().kind(),
            BackendKind::Grid
        );
    }

    // cell_coord checks across dimensions 2..=6: index = floor((p - origin) / size).

    #[test]
    fn cell_coord_d2() {
        let origin = Point::new([0.0_f64, 0.0]);
        let cell_size = [1.0_f64, 1.0];
        let point = Point::new([2.5_f64, 3.7]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [2, 3]);
    }

    #[test]
    fn cell_coord_d3() {
        let origin = Point::new([0.0_f64; 3]);
        let cell_size = [2.0_f64; 3];
        let point = Point::new([5.0_f64, 7.0, 9.0]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [2, 3, 4]);
    }

    #[test]
    fn cell_coord_d4() {
        let origin = Point::new([0.0_f64; 4]);
        let cell_size = [10.0_f64; 4];
        let point = Point::new([15.0_f64, 25.0, 35.0, 45.0]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [1, 2, 3, 4]);
    }

    #[test]
    fn cell_coord_d5() {
        // Points exactly on a cell boundary belong to the higher cell.
        let origin = Point::new([0.0_f64; 5]);
        let cell_size = [5.0_f64; 5];
        let point = Point::new([0.0_f64, 5.0, 10.0, 15.0, 20.0]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [0, 1, 2, 3, 4]);
    }

    #[test]
    fn cell_coord_d6() {
        let origin = Point::new([0.0_f64; 6]);
        let cell_size = [3.0_f64; 6];
        let point = Point::new([3.0_f64, 6.0, 9.0, 12.0, 15.0, 18.0]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn cell_coord_negative() {
        // floor (not truncation toward zero) for negative offsets.
        let origin = Point::new([0.0_f64, 0.0]);
        let cell_size = [1.0_f64, 1.0];
        let point = Point::new([-1.5_f64, -0.5]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [-2, -1]);
    }

    /// `bulk_load` on uniform data should pick a cell size giving roughly one
    /// point per occupied cell (accepted band: 0.5..=4.0).
    #[test]
    fn uniform_data_approx_one_point_per_cell() {
        let n = 10_000usize;
        let mut rng = Lcg::new(42);
        let entries: Vec<(Point<f64, 2>, usize)> = (0..n)
            .map(|i| {
                (
                    Point::new([rng.next_f64() * 1000.0, rng.next_f64() * 1000.0]),
                    i,
                )
            })
            .collect();
        let grid = GridIndex::<usize, f64, 2>::bulk_load(entries);
        let cell_count = grid.cells.len();
        let avg = n as f64 / cell_count as f64;
        assert!(
            (0.5..=4.0).contains(&avg),
            "avg points/cell = {avg:.2}, cell_count = {cell_count}"
        );
    }

    /// Deterministic large-scale oracle check: 10k random 2D points, one
    /// quarter-area query box.
    #[test]
    fn range_query_vs_brute_force_2d_10k() {
        let n = 10_000usize;
        let mut rng = Lcg::new(99);
        let mut grid = GridIndex::<usize, f64, 2>::new([10.0_f64, 10.0], Point::new([0.0, 0.0]));
        let mut pt_ids = Vec::new();
        for i in 0..n {
            let p = Point::new([rng.next_f64() * 1000.0, rng.next_f64() * 1000.0]);
            let id = grid.insert(p, i);
            pt_ids.push((p, id));
        }
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([500.0, 500.0]));
        let mut got: Vec<EntryId> = grid
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        got.sort_by_key(|id| id.0);
        let expected = brute_range(&pt_ids, &bbox);
        assert_eq!(got, expected, "2D 10k range query mismatch");
    }

    /// Strategy: a random 2D point in `[0, 1000)^2`.
    fn pt2d() -> impl Strategy<Value = Point<f64, 2>> {
        (0.0_f64..1000.0, 0.0_f64..1000.0).prop_map(|(x, y)| Point::new([x, y]))
    }

    /// Strategy: a random 2D box with its min corner in `[0, 900)^2` and
    /// width/height in `[10, 200)`.
    fn bbox2d() -> impl Strategy<Value = BBox<f64, 2>> {
        (
            0.0_f64..900.0,
            0.0_f64..900.0,
            10.0_f64..200.0,
            10.0_f64..200.0,
        )
            .prop_map(|(x, y, w, h)| BBox::new(Point::new([x, y]), Point::new([x + w, y + h])))
    }

    // Insert-Remove Round Trip: removed ids must vanish from range queries and
    // `len` must equal inserted minus removed.
    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_insert_remove_round_trip_grid(
            pts in prop::collection::vec(pt2d(), 1..50),
            remove_indices in prop::collection::vec(0usize..50, 0..25),
        ) {
            let mut grid = GridIndex::<usize, f64, 2>::new(
                [10.0_f64, 10.0],
                Point::new([0.0_f64, 0.0]),
            );
            let mut inserted: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, &p) in pts.iter().enumerate() {
                let id = grid.insert(p, i);
                inserted.push((p, id));
            }
            let mut removed_ids: Vec<EntryId> = Vec::new();
            for &ri in &remove_indices {
                // Wrap into range; duplicates are skipped so each id is removed once.
                let idx = ri % inserted.len();
                let (_, id) = inserted[idx];
                if !removed_ids.contains(&id) {
                    let result = grid.remove(id);
                    prop_assert!(result.is_some(), "remove returned None for inserted id");
                    removed_ids.push(id);
                }
            }
            // Use a bbox that covers the point space (0..1000) without being huge,
            // which keeps the number of cells the query spans small.
            let full_bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([1000.0, 1000.0]));
            let remaining_ids: Vec<EntryId> = grid.range_query(&full_bbox)
                .into_iter()
                .map(|(id, _)| id)
                .collect();
            for &removed_id in &removed_ids {
                prop_assert!(
                    !remaining_ids.contains(&removed_id),
                    "removed entry {:?} still appears in range query",
                    removed_id
                );
            }
            let expected_len = inserted.len() - removed_ids.len();
            prop_assert_eq!(grid.len(), expected_len);
        }
    }

    // Range Query Oracle: range_query must agree with a brute-force linear
    // scan for any random dataset and bbox.
    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 200,
            ..Default::default()
        })]

        #[test]
        fn prop_range_query_oracle_grid(
            pts in prop::collection::vec(pt2d(), 1..100),
            bbox in bbox2d(),
        ) {
            let mut grid = GridIndex::<usize, f64, 2>::new(
                [10.0_f64, 10.0],
                Point::new([0.0_f64, 0.0]),
            );
            let mut pt_ids: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, p) in pts.iter().enumerate() {
                let id = grid.insert(*p, i);
                pt_ids.push((*p, id));
            }
            let mut got: Vec<EntryId> =
                grid.range_query(&bbox).into_iter().map(|(id, _)| id).collect();
            got.sort_by_key(|id| id.0);
            let expected = brute_range(&pt_ids, &bbox);
            prop_assert_eq!(got, expected);
        }
    }
}