1use std::collections::HashMap;
7
8use crate::backends::SpatialBackend;
9use crate::types::{BBox, BackendKind, CoordType, EntryId, Point};
10
/// Integer coordinates of one grid cell: on each axis, the cell index obtained by
/// flooring `(point - origin) / cell_size` (see `cell_coord`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct GridCoord<const D: usize>([i32; D]);
14
/// Uniform-grid spatial index over `D`-dimensional points carrying payloads of type `T`.
///
/// Space is partitioned into axis-aligned cells of width `cell_size`, positioned
/// relative to `origin`; only cells that currently hold at least one entry are stored.
pub struct GridIndex<T, C, const D: usize> {
    // Occupied cells only: a cell is created on first insert into it and dropped when
    // its last entry is removed. Each entry is stored as `(point, payload, id)`.
    cells: HashMap<GridCoord<D>, Vec<(Point<C, D>, T, EntryId)>>,
    // Cell width along each axis; a width of exactly zero maps that whole axis to index 0.
    cell_size: [C; D],
    // Reference point: a point equal to `origin` falls in cell `[0; D]`.
    origin: Point<C, D>,
    // Number of entries currently stored.
    len: usize,
    // Next raw id handed out by `alloc_id`; ids are allocated sequentially and never reused.
    next_id: u64,
    // Raw entry id (`EntryId.0`) -> cell holding that entry, so removal scans a single cell.
    id_to_cell: HashMap<u64, GridCoord<D>>,
}
33
34fn cell_coord<C: CoordType, const D: usize>(
35 point: &Point<C, D>,
36 origin: &Point<C, D>,
37 cell_size: &[C; D],
38) -> GridCoord<D> {
39 let mut coord = [0i32; D];
40 for d in 0..D {
41 let p: f64 = point.coords()[d].into();
42 let o: f64 = origin.coords()[d].into();
43 let s: f64 = cell_size[d].into();
44 coord[d] = if s == 0.0 {
45 0
46 } else {
47 ((p - o) / s).floor() as i32
48 };
49 }
50 GridCoord(coord)
51}
52
53fn cell_range<C: CoordType, const D: usize>(
54 bbox: &BBox<C, D>,
55 origin: &Point<C, D>,
56 cell_size: &[C; D],
57) -> ([i32; D], [i32; D]) {
58 let mut min_coord = [0i32; D];
59 let mut max_coord = [0i32; D];
60 for d in 0..D {
61 let lo: f64 = bbox.min.coords()[d].into();
62 let hi: f64 = bbox.max.coords()[d].into();
63 let o: f64 = origin.coords()[d].into();
64 let s: f64 = cell_size[d].into();
65 if s == 0.0 {
66 min_coord[d] = 0;
67 max_coord[d] = 0;
68 } else {
69 min_coord[d] = ((lo - o) / s).floor() as i32;
70 max_coord[d] = ((hi - o) / s).floor() as i32;
71 }
72 }
73 (min_coord, max_coord)
74}
75
76fn for_each_cell_in_range<const D: usize, F: FnMut(GridCoord<D>)>(
79 min: &[i32; D],
80 max: &[i32; D],
81 f: &mut F,
82) {
83 let mut current = *min;
84 loop {
85 f(GridCoord(current));
86
87 let mut carry = true;
88 for d in (0..D).rev() {
89 if carry {
90 if current[d] < max[d] {
91 current[d] += 1;
92 carry = false;
93 } else {
94 current[d] = min[d];
95 }
96 }
97 }
98 if carry {
99 break;
100 }
101 }
102}
103
104impl<T, C: CoordType, const D: usize> GridIndex<T, C, D> {
105 pub fn new(cell_size: [C; D], origin: Point<C, D>) -> Self {
107 Self {
108 cells: HashMap::new(),
109 cell_size,
110 origin,
111 len: 0,
112 next_id: 0,
113 id_to_cell: HashMap::new(),
114 }
115 }
116
117 fn alloc_id(&mut self) -> EntryId {
118 let id = EntryId(self.next_id);
119 self.next_id += 1;
120 id
121 }
122
123 fn default_cell_size(bbox: &BBox<C, D>, n: usize) -> [C; D] {
128 let n_f = (n.max(1) as f64).powf(1.0 / D as f64);
129 let mut cs = [C::zero(); D];
130 for (d, c) in cs.iter_mut().enumerate().take(D) {
131 let span: f64 = (bbox.max.coords()[d] - bbox.min.coords()[d]).into();
132 let s = (span / n_f).max(1.0);
133 *c = C::from(s as f32);
134 }
135 cs
136 }
137
138 pub fn cell_count(&self) -> usize {
140 self.cells.len()
141 }
142
143 pub fn insert_entry(&mut self, point: Point<C, D>, payload: T) -> EntryId {
144 let id = self.alloc_id();
145 let coord = cell_coord(&point, &self.origin, &self.cell_size);
146 self.id_to_cell.insert(id.0, coord);
147 self.cells
148 .entry(coord)
149 .or_default()
150 .push((point, payload, id));
151 self.len += 1;
152 id
153 }
154
155 pub fn remove_entry(&mut self, id: EntryId) -> Option<T> {
156 let coord = self.id_to_cell.remove(&id.0)?;
157 let cell = self.cells.get_mut(&coord)?;
158 let pos = cell.iter().position(|(_, _, eid)| *eid == id)?;
159 let (_, payload, _) = cell.swap_remove(pos);
160 if cell.is_empty() {
161 self.cells.remove(&coord);
162 }
163 self.len -= 1;
164 Some(payload)
165 }
166
167 pub fn range_query_impl<'a>(&'a self, bbox: &BBox<C, D>) -> Vec<(EntryId, &'a T)> {
168 let (min_coord, max_coord) = cell_range(bbox, &self.origin, &self.cell_size);
169 let mut out = Vec::new();
170 for_each_cell_in_range(&min_coord, &max_coord, &mut |coord| {
171 if let Some(cell) = self.cells.get(&coord) {
172 for (point, payload, id) in cell {
173 if bbox.contains_point(point) {
174 out.push((*id, payload));
175 }
176 }
177 }
178 });
179 out
180 }
181
182 pub fn knn_query_impl<'a>(
183 &'a self,
184 point: &Point<C, D>,
185 k: usize,
186 ) -> Vec<(f64, EntryId, &'a T)> {
187 if k == 0 {
188 return Vec::new();
189 }
190 let mut all: Vec<(f64, EntryId, &'a T)> = self
191 .cells
192 .values()
193 .flat_map(|cell| cell.iter())
194 .map(|(p, payload, id)| (point_dist_sq(p, point).sqrt(), *id, payload))
195 .collect();
196 all.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
197 all.truncate(k);
198 all
199 }
200
201 fn collect_all(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
202 self.cells
203 .values()
204 .flat_map(|cell| cell.iter())
205 .map(|(p, payload, id)| (*p, *id, payload))
206 .collect()
207 }
208}
209
210fn point_dist_sq<C: CoordType, const D: usize>(a: &Point<C, D>, b: &Point<C, D>) -> f64 {
211 let mut sum = 0.0_f64;
212 for d in 0..D {
213 let da: f64 = a.coords()[d].into();
214 let db: f64 = b.coords()[d].into();
215 let diff = da - db;
216 sum += diff * diff;
217 }
218 sum
219}
220
221impl<T, C: CoordType, const D: usize> Default for GridIndex<T, C, D> {
222 fn default() -> Self {
223 Self::new([C::from(1.0_f32); D], Point::new([C::zero(); D]))
224 }
225}
226
227impl<T: Send + Sync + 'static, C: CoordType, const D: usize> SpatialBackend<T, C, D>
228 for GridIndex<T, C, D>
229{
230 fn insert(&mut self, point: Point<C, D>, payload: T) -> EntryId {
231 self.insert_entry(point, payload)
232 }
233
234 fn remove(&mut self, id: EntryId) -> Option<T> {
235 self.remove_entry(id)
236 }
237
238 fn range_query(&self, bbox: &BBox<C, D>) -> Vec<(EntryId, &T)> {
239 self.range_query_impl(bbox)
240 }
241
242 fn knn_query(&self, point: &Point<C, D>, k: usize) -> Vec<(f64, EntryId, &T)> {
243 self.knn_query_impl(point, k)
244 }
245
246 fn spatial_join(&self, other: &dyn SpatialBackend<T, C, D>) -> Vec<(EntryId, EntryId)> {
247 let self_entries = self.collect_all();
248 let other_entries = other.all_entries();
249 let mut pairs = Vec::new();
250 for (pa, id_a, _) in &self_entries {
251 let bbox_a = BBox::new(*pa, *pa);
252 for (pb, id_b, _) in &other_entries {
253 if bbox_a.intersects(&BBox::new(*pb, *pb)) {
254 pairs.push((*id_a, *id_b));
255 }
256 }
257 }
258 pairs
259 }
260
261 fn bulk_load(entries: Vec<(Point<C, D>, T)>) -> Self {
262 if entries.is_empty() {
263 return Self::default();
264 }
265 let mut min_c = *entries[0].0.coords();
266 let mut max_c = *entries[0].0.coords();
267 for (p, _) in &entries {
268 for d in 0..D {
269 let v = p.coords()[d];
270 if v < min_c[d] {
271 min_c[d] = v;
272 }
273 if v > max_c[d] {
274 max_c[d] = v;
275 }
276 }
277 }
278 let bbox = BBox::new(Point::new(min_c), Point::new(max_c));
279 let cell_size = Self::default_cell_size(&bbox, entries.len());
280 let origin = Point::new(min_c);
281 let mut grid = Self::new(cell_size, origin);
282 for (point, payload) in entries {
283 grid.insert_entry(point, payload);
284 }
285 grid
286 }
287
288 fn len(&self) -> usize {
289 self.len
290 }
291
292 fn kind(&self) -> BackendKind {
293 BackendKind::Grid
294 }
295
296 fn all_entries(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
297 self.collect_all()
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Minimal deterministic 64-bit linear congruential generator used to build
    /// reproducible test data.
    struct Lcg(u64);
    impl Lcg {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
        /// Advances the state and returns a value in `[0, 1)` built from the top
        /// 53 bits of the new state.
        fn next_f64(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
    }

    /// Reference answer: ids of the points in `pts` that `bbox` contains, found by
    /// linear scan and sorted by raw id.
    fn brute_range<C: CoordType, const D: usize>(
        pts: &[(Point<C, D>, EntryId)],
        bbox: &BBox<C, D>,
    ) -> Vec<EntryId> {
        let mut ids: Vec<EntryId> = pts
            .iter()
            .filter(|(p, _)| bbox.contains_point(p))
            .map(|(_, id)| *id)
            .collect();
        ids.sort_by_key(|id| id.0);
        ids
    }

    // `len` tracks the number of inserted entries.
    #[test]
    fn insert_and_len() {
        let mut grid = GridIndex::<u32, f64, 2>::default();
        assert_eq!(grid.len(), 0);
        grid.insert(Point::new([0.5, 0.5]), 1u32);
        assert_eq!(grid.len(), 1);
        grid.insert(Point::new([1.5, 1.5]), 2u32);
        assert_eq!(grid.len(), 2);
    }

    // A box spanning two unit cells returns exactly the two points inside it.
    #[test]
    fn range_query_basic() {
        let mut grid = GridIndex::<u32, f64, 2>::default();
        let id1 = grid.insert(Point::new([0.5, 0.5]), 1u32);
        let id2 = grid.insert(Point::new([1.5, 1.5]), 2u32);
        let _id3 = grid.insert(Point::new([5.0, 5.0]), 3u32);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([2.0, 2.0]));
        let mut got: Vec<EntryId> = grid
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        got.sort_by_key(|id| id.0);
        assert_eq!(got, vec![id1, id2]);
    }

    // Removal returns the payload, decrements `len`, and hides the entry from queries.
    #[test]
    fn remove_works() {
        let mut grid = GridIndex::<u32, f64, 2>::default();
        let id1 = grid.insert(Point::new([1.0, 1.0]), 10u32);
        let id2 = grid.insert(Point::new([2.0, 2.0]), 20u32);
        assert_eq!(grid.len(), 2);
        assert_eq!(grid.remove(id1), Some(10u32));
        assert_eq!(grid.len(), 1);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([3.0, 3.0]));
        let ids: Vec<EntryId> = grid
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert!(!ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[test]
    fn kind_is_grid() {
        assert_eq!(
            GridIndex::<u32, f64, 2>::default().kind(),
            BackendKind::Grid
        );
    }

    // `cell_coord` floors `(point - origin) / cell_size` on every axis, for D = 2..6.
    #[test]
    fn cell_coord_d2() {
        let origin = Point::new([0.0_f64, 0.0]);
        let cell_size = [1.0_f64, 1.0];
        let point = Point::new([2.5_f64, 3.7]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [2, 3]);
    }

    #[test]
    fn cell_coord_d3() {
        let origin = Point::new([0.0_f64; 3]);
        let cell_size = [2.0_f64; 3];
        let point = Point::new([5.0_f64, 7.0, 9.0]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [2, 3, 4]);
    }

    #[test]
    fn cell_coord_d4() {
        let origin = Point::new([0.0_f64; 4]);
        let cell_size = [10.0_f64; 4];
        let point = Point::new([15.0_f64, 25.0, 35.0, 45.0]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [1, 2, 3, 4]);
    }

    // Points exactly on a cell boundary belong to the higher cell.
    #[test]
    fn cell_coord_d5() {
        let origin = Point::new([0.0_f64; 5]);
        let cell_size = [5.0_f64; 5];
        let point = Point::new([0.0_f64, 5.0, 10.0, 15.0, 20.0]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [0, 1, 2, 3, 4]);
    }

    #[test]
    fn cell_coord_d6() {
        let origin = Point::new([0.0_f64; 6]);
        let cell_size = [3.0_f64; 6];
        let point = Point::new([3.0_f64, 6.0, 9.0, 12.0, 15.0, 18.0]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [1, 2, 3, 4, 5, 6]);
    }

    // Flooring (not truncation toward zero) puts negative offsets in negative cells.
    #[test]
    fn cell_coord_negative() {
        let origin = Point::new([0.0_f64, 0.0]);
        let cell_size = [1.0_f64, 1.0];
        let point = Point::new([-1.5_f64, -0.5]);
        let coord = cell_coord(&point, &origin, &cell_size);
        assert_eq!(coord.0, [-2, -1]);
    }

    // `bulk_load`'s default cell size should give roughly one point per occupied cell
    // for uniformly spread data.
    #[test]
    fn uniform_data_approx_one_point_per_cell() {
        let n = 10_000usize;
        let mut rng = Lcg::new(42);
        let entries: Vec<(Point<f64, 2>, usize)> = (0..n)
            .map(|i| {
                (
                    Point::new([rng.next_f64() * 1000.0, rng.next_f64() * 1000.0]),
                    i,
                )
            })
            .collect();
        let grid = GridIndex::<usize, f64, 2>::bulk_load(entries);
        let cell_count = grid.cells.len();
        let avg = n as f64 / cell_count as f64;
        assert!(
            (0.5..=4.0).contains(&avg),
            "avg points/cell = {avg:.2}, cell_count = {cell_count}"
        );
    }

    // Large deterministic oracle check against `brute_range`.
    #[test]
    fn range_query_vs_brute_force_2d_10k() {
        let n = 10_000usize;
        let mut rng = Lcg::new(99);
        let mut grid = GridIndex::<usize, f64, 2>::new([10.0_f64, 10.0], Point::new([0.0, 0.0]));
        let mut pt_ids = Vec::new();
        for i in 0..n {
            let p = Point::new([rng.next_f64() * 1000.0, rng.next_f64() * 1000.0]);
            let id = grid.insert(p, i);
            pt_ids.push((p, id));
        }
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([500.0, 500.0]));
        let mut got: Vec<EntryId> = grid
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        got.sort_by_key(|id| id.0);
        let expected = brute_range(&pt_ids, &bbox);
        assert_eq!(got, expected, "2D 10k range query mismatch");
    }

    /// Random point in `[0, 1000)²`.
    fn pt2d() -> impl Strategy<Value = Point<f64, 2>> {
        (0.0_f64..1000.0, 0.0_f64..1000.0).prop_map(|(x, y)| Point::new([x, y]))
    }

    /// Random box with its lower corner in `[0, 900)²` and side lengths in `[10, 200)`.
    fn bbox2d() -> impl Strategy<Value = BBox<f64, 2>> {
        (
            0.0_f64..900.0,
            0.0_f64..900.0,
            10.0_f64..200.0,
            10.0_f64..200.0,
        )
            .prop_map(|(x, y, w, h)| BBox::new(Point::new([x, y]), Point::new([x + w, y + h])))
    }

    proptest! {
        // 100 cases keeps this property test fast.
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        // Removing a random subset of inserted entries succeeds once per id, hides those
        // entries from a full-extent range query, and shrinks `len` accordingly.
        #[test]
        fn prop_insert_remove_round_trip_grid(
            pts in prop::collection::vec(pt2d(), 1..50),
            remove_indices in prop::collection::vec(0usize..50, 0..25),
        ) {
            let mut grid = GridIndex::<usize, f64, 2>::new(
                [10.0_f64, 10.0],
                Point::new([0.0_f64, 0.0]),
            );
            let mut inserted: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, &p) in pts.iter().enumerate() {
                let id = grid.insert(p, i);
                inserted.push((p, id));
            }
            let mut removed_ids: Vec<EntryId> = Vec::new();
            for &ri in &remove_indices {
                // Wrap the raw index into range; duplicates are skipped below.
                let idx = ri % inserted.len();
                let (_, id) = inserted[idx];
                if !removed_ids.contains(&id) {
                    let result = grid.remove(id);
                    prop_assert!(result.is_some(), "remove returned None for inserted id");
                    removed_ids.push(id);
                }
            }
            let full_bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([1000.0, 1000.0]));
            // This box covers the whole `pt2d` domain, so it sees every remaining entry.
            let remaining_ids: Vec<EntryId> = grid.range_query(&full_bbox)
                .into_iter()
                .map(|(id, _)| id)
                .collect();
            for &removed_id in &removed_ids {
                prop_assert!(
                    !remaining_ids.contains(&removed_id),
                    "removed entry {:?} still appears in range query",
                    removed_id
                );
            }
            let expected_len = inserted.len() - removed_ids.len();
            prop_assert_eq!(grid.len(), expected_len);
        }
    }

    proptest! {
        // 200 random point-set / box combinations.
        #![proptest_config(proptest::test_runner::Config {
            cases: 200,
            ..Default::default()
        })]

        // The grid's range query agrees with a brute-force scan.
        #[test]
        fn prop_range_query_oracle_grid(
            pts in prop::collection::vec(pt2d(), 1..100),
            bbox in bbox2d(),
        ) {
            let mut grid = GridIndex::<usize, f64, 2>::new(
                [10.0_f64, 10.0],
                Point::new([0.0_f64, 0.0]),
            );
            let mut pt_ids: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, p) in pts.iter().enumerate() {
                let id = grid.insert(*p, i);
                pt_ids.push((*p, id));
            }
            let mut got: Vec<EntryId> =
                grid.range_query(&bbox).into_iter().map(|(id, _)| id).collect();
            got.sort_by_key(|id| id.0);
            let expected = brute_range(&pt_ids, &bbox);
            prop_assert_eq!(got, expected);
        }
    }
}