use crate::error::{SpatialError, SpatialResult};
use crate::safe_conversions::*;
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::marker::PhantomData;
/// Hashable identifier of one grid cell: its integer coordinate along each axis.
///
/// The coordinates are kept in a fixed-size array so the key can be `Copy`;
/// only the first `dims` entries are meaningful.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CellKey {
/// Integer cell coordinates. `CellKey::new` leaves entries at index >= `dims` as zero.
coords: [i64; MAX_DIMS],
/// Number of leading entries of `coords` that are in use.
dims: usize,
}
/// Largest number of dimensions a `CellKey` (and therefore a `GridIndex`) can have.
const MAX_DIMS: usize = 16;
impl CellKey {
    /// Builds a key from per-axis integer cell coordinates.
    ///
    /// Unused trailing slots of the backing array are zero-filled so that two
    /// keys with the same coordinates always hash and compare equal.
    ///
    /// # Errors
    /// Returns `SpatialError::ValueError` when `coords` has more than
    /// `MAX_DIMS` entries.
    fn new(coords: &[i64]) -> SpatialResult<Self> {
        let dims = coords.len();
        if dims > MAX_DIMS {
            return Err(SpatialError::ValueError(format!(
                "Grid index supports at most {} dimensions, got {}",
                MAX_DIMS, dims
            )));
        }

        let mut padded = [0i64; MAX_DIMS];
        padded[..dims].copy_from_slice(coords);

        Ok(CellKey {
            coords: padded,
            dims,
        })
    }

    /// Returns the coordinates that are actually in use (the first `dims` entries).
    fn dim_coords(&self) -> &[i64] {
        let (used, _padding) = self.coords.split_at(self.dims);
        used
    }
}
/// A point held inside a grid cell: the caller-supplied id plus an owned copy
/// of its coordinates.
#[derive(Clone, Debug)]
struct StoredPoint<T: Float> {
/// Identifier passed to `GridIndex::insert`; returned by queries.
id: usize,
/// Coordinates copied from the slice given at insertion time.
coords: Vec<T>,
}
/// Spatial index that buckets points into axis-aligned cells of equal size,
/// stored sparsely in a hash map keyed by integer cell coordinates.
#[derive(Clone, Debug)]
pub struct GridIndex<T: Float> {
/// Edge length of a cell, in coordinate units.
cell_size: T,
/// Reciprocal of `cell_size`; multiplied with a coordinate to get its cell index.
inv_cell_size: T,
/// Dimensionality every inserted point and every query must have.
n_dims: usize,
/// Cells that have received points, keyed by their integer coordinates.
cells: HashMap<CellKey, Vec<StoredPoint<T>>>,
/// Total number of stored points across all cells.
count: usize,
/// Marker for the coordinate type `T`.
_phantom: PhantomData<T>,
}
impl<T: Float + 'static> GridIndex<T> {
pub fn new(cell_size: f64, n_dims: usize) -> SpatialResult<Self> {
if cell_size <= 0.0 {
return Err(SpatialError::ValueError(
"Cell size must be positive".to_string(),
));
}
if n_dims == 0 || n_dims > MAX_DIMS {
return Err(SpatialError::ValueError(format!(
"Dimensions must be in [1, {}], got {}",
MAX_DIMS, n_dims
)));
}
let cs: T = safe_from(cell_size, "grid cell_size")?;
let ics: T = safe_from(1.0 / cell_size, "grid inv_cell_size")?;
Ok(GridIndex {
cell_size: cs,
inv_cell_size: ics,
n_dims,
cells: HashMap::new(),
count: 0,
_phantom: PhantomData,
})
}
/// Returns the number of points currently stored in the index.
pub fn len(&self) -> usize {
self.count
}
/// Returns `true` when the index holds no points.
pub fn is_empty(&self) -> bool {
self.count == 0
}
/// Returns the number of cell entries currently held in the internal hash map.
pub fn cell_count(&self) -> usize {
self.cells.len()
}
/// Returns the dimensionality the index was created with.
pub fn dims(&self) -> usize {
self.n_dims
}
/// Returns the cell edge length, as stored in the coordinate type `T`.
pub fn cell_size(&self) -> T {
self.cell_size
}
/// Adds a point with identifier `id` at position `coords`.
///
/// The coordinates are copied into the index. Ids are not checked for
/// uniqueness.
///
/// # Errors
/// * `SpatialError::DimensionError` if `coords.len()` differs from the
///   grid's dimensionality.
/// * Any error produced while computing the cell key.
pub fn insert(&mut self, id: usize, coords: &[T]) -> SpatialResult<()> {
    let got = coords.len();
    if got != self.n_dims {
        return Err(SpatialError::DimensionError(format!(
            "Expected {} dims, got {}",
            self.n_dims, got
        )));
    }

    let key = self.cell_key(coords)?;
    let bucket = self.cells.entry(key).or_default();
    bucket.push(StoredPoint {
        id,
        coords: coords.to_vec(),
    });
    self.count += 1;
    Ok(())
}
/// Inserts every `(id, coords)` pair in order, stopping at the first failure.
///
/// Points inserted before a failing entry remain in the index.
///
/// # Errors
/// Propagates the first error returned by [`insert`](Self::insert).
pub fn insert_batch(&mut self, points: &[(usize, Vec<T>)]) -> SpatialResult<()> {
    points
        .iter()
        .try_for_each(|(id, coords)| self.insert(*id, coords))
}
/// Removes one point whose identifier equals `id`.
///
/// Every cell is scanned, so the cost is linear in the number of stored
/// points. If several points share the id, only the first one encountered is
/// removed. A cell that becomes empty is dropped from the map so that
/// `cell_count()` keeps reflecting occupied cells and memory is released.
///
/// Returns `true` if a point was removed, `false` if no point has that id.
pub fn remove(&mut self, id: usize) -> bool {
    // Locate the point first; `CellKey` is `Copy`, so the immutable borrow of
    // the map ends before it is mutated below.
    let located = self.cells.iter().find_map(|(key, points)| {
        points
            .iter()
            .position(|sp| sp.id == id)
            .map(|pos| (*key, pos))
    });

    let (key, pos) = match located {
        Some(hit) => hit,
        None => return false,
    };

    let now_empty = match self.cells.get_mut(&key) {
        Some(points) => {
            // Order inside a cell carries no meaning, so O(1) removal is fine.
            points.swap_remove(pos);
            points.is_empty()
        }
        None => false,
    };
    if now_empty {
        self.cells.remove(&key);
    }

    self.count -= 1;
    true
}
/// Removes every point and every cell; cell size and dimensionality are kept.
pub fn clear(&mut self) {
self.cells.clear();
self.count = 0;
}
/// Finds every stored point within Euclidean distance `radius` of `query`
/// (boundary included).
///
/// Returns `(ids, distances)` as parallel vectors ordered by ascending
/// distance.
///
/// # Errors
/// * `SpatialError::DimensionError` if `query.len()` differs from the grid's
///   dimensionality.
/// * `SpatialError::ValueError` if `radius` is negative or NaN.
/// * Any error from converting `radius` to `T` or from computing the query's
///   cell key.
pub fn query_radius(&self, query: &[T], radius: f64) -> SpatialResult<(Vec<usize>, Vec<T>)> {
    if query.len() != self.n_dims {
        return Err(SpatialError::DimensionError(format!(
            "Expected {} dims, got {}",
            self.n_dims,
            query.len()
        )));
    }
    // `radius < 0.0` alone is false for NaN, so NaN is tested explicitly.
    if radius.is_nan() || radius < 0.0 {
        return Err(SpatialError::ValueError(
            "Radius must be non-negative".to_string(),
        ));
    }

    let r: T = safe_from(radius, "query radius")?;
    // Compare squared distances so the square root is only taken for hits.
    let r_sq = r * r;
    // Number of cells the search ball can reach beyond the query's own cell.
    let cell_radius = (radius / self.to_f64(self.cell_size)).ceil() as i64;
    let center_key = self.cell_key(query)?;

    let mut results: Vec<(usize, T)> = Vec::new();
    self.for_each_neighbor_cell(&center_key, cell_radius, |cell_points| {
        for sp in cell_points {
            let d_sq = self.squared_distance(query, &sp.coords);
            if d_sq <= r_sq {
                results.push((sp.id, d_sq.sqrt()));
            }
        }
    });

    // Stable sort: points at equal distance keep their visit order.
    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(results.into_iter().unzip())
}
/// Finds the `k` stored points nearest to `query` by Euclidean distance.
///
/// Returns `(ids, distances)` as parallel vectors ordered by ascending
/// distance; ties are broken by ascending id so the result is deterministic.
/// If fewer than `k` points are stored, all of them are returned.
///
/// The search grows a cube of cells around the query's cell. A candidate set
/// is only accepted once its k-th distance is no larger than the distance
/// the cube is guaranteed to cover (`cell_radius * cell_size`). Without that
/// check a point in a corner of the cube could be returned ahead of a nearer
/// point that lies just outside it. When the cube would contain more cells
/// than are occupied, or the ring limit is reached, every stored point is
/// scanned instead, so the result is always exact and complete.
///
/// # Errors
/// * `SpatialError::DimensionError` if `query.len()` differs from the grid's
///   dimensionality.
/// * Any error from computing the query's cell key.
pub fn query_knn(&self, query: &[T], k: usize) -> SpatialResult<(Vec<usize>, Vec<T>)> {
    /// Orders candidates by distance, then by id for a deterministic result.
    fn nearest_first<D: PartialOrd>(a: &(usize, D), b: &(usize, D)) -> std::cmp::Ordering {
        a.1.partial_cmp(&b.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    }

    // Upper bound on ring expansion before switching to a full scan.
    const MAX_CELL_RADIUS: i64 = 64;

    if query.len() != self.n_dims {
        return Err(SpatialError::DimensionError(format!(
            "Expected {} dims, got {}",
            self.n_dims,
            query.len()
        )));
    }
    if k == 0 {
        return Ok((vec![], vec![]));
    }
    let k = k.min(self.count);
    let center_key = self.cell_key(query)?;
    if k == 0 {
        // The index is empty.
        return Ok((vec![], vec![]));
    }

    let cell_size = self.to_f64(self.cell_size);
    let occupied = self.cells.len();

    for cell_radius in 1..=MAX_CELL_RADIUS {
        // (2r + 1)^dims cells would be probed; once that exceeds the number
        // of occupied cells a full scan is cheaper.
        let side = (2 * cell_radius + 1) as usize;
        let cube_cells = (0..self.n_dims).try_fold(1usize, |acc, _| acc.checked_mul(side));
        if cube_cells.map_or(true, |cells| cells > occupied) {
            break;
        }

        let mut candidates: Vec<(usize, T)> = Vec::new();
        self.for_each_neighbor_cell(&center_key, cell_radius, |cell_points| {
            for sp in cell_points {
                let d = self.euclidean_distance(query, &sp.coords);
                candidates.push((sp.id, d));
            }
        });

        if candidates.len() >= k {
            candidates.sort_by(nearest_first);
            // The query lies inside the center cell, so every point closer
            // than `cell_radius * cell_size` is inside the searched cube.
            let covered = cell_radius as f64 * cell_size;
            if self.to_f64(candidates[k - 1].1) <= covered {
                candidates.truncate(k);
                return Ok(candidates.into_iter().unzip());
            }
        }
    }

    // Exhaustive fallback: always exact, independent of how far away the
    // nearest points are.
    let mut all: Vec<(usize, T)> = self
        .cells
        .values()
        .flatten()
        .map(|sp| (sp.id, self.euclidean_distance(query, &sp.coords)))
        .collect();
    all.sort_by(nearest_first);
    all.truncate(k);
    Ok(all.into_iter().unzip())
}
/// Reports whether at least one stored point lies within Euclidean distance
/// `radius` of `query` (boundary included).
///
/// # Errors
/// * `SpatialError::DimensionError` if `query.len()` differs from the grid's
///   dimensionality.
/// * `SpatialError::ValueError` if `radius` is negative or NaN (the same rule
///   as `query_radius`).
/// * Any error from converting `radius` to `T` or from computing the query's
///   cell key.
pub fn has_neighbor(&self, query: &[T], radius: f64) -> SpatialResult<bool> {
    if query.len() != self.n_dims {
        return Err(SpatialError::DimensionError(format!(
            "Expected {} dims, got {}",
            self.n_dims,
            query.len()
        )));
    }
    // Squaring a negative radius would otherwise turn it into a valid
    // threshold; NaN needs its own test because `< 0.0` is false for it.
    if radius.is_nan() || radius < 0.0 {
        return Err(SpatialError::ValueError(
            "Radius must be non-negative".to_string(),
        ));
    }

    let r: T = safe_from(radius, "has_neighbor radius")?;
    let r_sq = r * r;
    let cell_radius = (radius / self.to_f64(self.cell_size)).ceil() as i64;
    let center_key = self.cell_key(query)?;

    let mut found = false;
    self.for_each_neighbor_cell(&center_key, cell_radius, |cell_points| {
        // The visitor cannot abort the traversal, so skip the distance work
        // once a neighbor is known.
        if !found {
            found = cell_points
                .iter()
                .any(|sp| self.squared_distance(query, &sp.coords) <= r_sq);
        }
    });
    Ok(found)
}
/// Collects the ids of all stored points.
///
/// The order follows the hash map's iteration order and is therefore
/// unspecified.
pub fn all_ids(&self) -> Vec<usize> {
    let mut ids = Vec::with_capacity(self.count);
    ids.extend(self.cells.values().flatten().map(|sp| sp.id));
    ids
}
fn cell_key(&self, coords: &[T]) -> SpatialResult<CellKey> {
let mut ints = Vec::with_capacity(self.n_dims);
for &c in coords.iter().take(self.n_dims) {
let idx = (c * self.inv_cell_size).floor();
let i = self.to_i64(idx);
ints.push(i);
}
CellKey::new(&ints)
}
fn for_each_neighbor_cell<F>(&self, center_key: &CellKey, cell_radius: i64, mut f: F)
where
F: FnMut(&[StoredPoint<T>]),
{
let dims = center_key.dims;
let center = center_key.dim_coords();
let mut offsets = vec![vec![0i64; dims]];
for d in 0..dims {
let mut new_offsets = Vec::new();
for existing in &offsets {
for delta in -cell_radius..=cell_radius {
let mut combo = existing.clone();
combo[d] = center[d] + delta;
new_offsets.push(combo);
}
}
offsets = new_offsets;
}
for offset in &offsets {
if let Ok(key) = CellKey::new(offset) {
if let Some(cell_points) = self.cells.get(&key) {
if !cell_points.is_empty() {
f(cell_points);
}
}
}
}
}
/// Sum of squared per-axis differences between `a` and `b` over the first
/// `n_dims` coordinates.
///
/// # Panics
/// Panics if either slice has fewer than `n_dims` elements.
fn squared_distance(&self, a: &[T], b: &[T]) -> T {
    let n = self.n_dims;
    a[..n]
        .iter()
        .zip(&b[..n])
        .fold(T::zero(), |acc, (&x, &y)| {
            let diff = x - y;
            acc + diff * diff
        })
}
/// Euclidean distance between `a` and `b`: the square root of `squared_distance`.
fn euclidean_distance(&self, a: &[T], b: &[T]) -> T {
self.squared_distance(a, b).sqrt()
}
/// Converts `val` to `f64`, falling back to `0.0` if the conversion yields `None`.
fn to_f64(&self, val: T) -> f64 {
val.to_f64().unwrap_or(0.0)
}
/// Converts `val` to `i64` by way of `f64` (using `0.0` if that conversion yields `None`).
///
/// The final `as` cast truncates toward zero, saturates at the `i64` bounds
/// and maps NaN to 0.
fn to_i64(&self, val: T) -> i64 {
val.to_f64().unwrap_or(0.0) as i64
}
}
impl<T: Float + 'static> GridIndex<T> {
    /// Builds an index from a 2-D array view in which each row is one point.
    ///
    /// The number of columns becomes the grid's dimensionality and each
    /// point's id is its row index.
    ///
    /// # Errors
    /// Propagates errors from `GridIndex::new` (invalid `cell_size` or column
    /// count) and from `insert`.
    pub fn from_array(
        data: &scirs2_core::ndarray::ArrayView2<T>,
        cell_size: f64,
    ) -> SpatialResult<Self> {
        let (n_points, n_dims) = (data.nrows(), data.ncols());
        let mut grid = Self::new(cell_size, n_dims)?;
        (0..n_points).try_for_each(|i| grid.insert(i, &data.row(i).to_vec()))?;
        Ok(grid)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    // A freshly created grid is empty and remembers its dimensionality.
    #[test]
    fn test_create_grid() {
        let grid = GridIndex::<f64>::new(1.0, 2);
        assert!(grid.is_ok());
        let grid = grid.expect("should create");
        assert_eq!(grid.len(), 0);
        assert!(grid.is_empty());
        assert_eq!(grid.dims(), 2);
    }

    // Zero and negative cell sizes are rejected.
    #[test]
    fn test_invalid_cell_size() {
        let result = GridIndex::<f64>::new(0.0, 2);
        assert!(result.is_err());
        let result = GridIndex::<f64>::new(-1.0, 2);
        assert!(result.is_err());
    }

    // Dimensionality must lie in 1..=MAX_DIMS.
    #[test]
    fn test_invalid_dims() {
        let result = GridIndex::<f64>::new(1.0, 0);
        assert!(result.is_err());
        let result = GridIndex::<f64>::new(1.0, 17);
        assert!(result.is_err());
    }

    // Each insert increases the point count.
    #[test]
    fn test_insert_and_count() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.5, 0.5]).expect("insert");
        grid.insert(1, &[1.5, 0.5]).expect("insert");
        grid.insert(2, &[0.5, 1.5]).expect("insert");
        assert_eq!(grid.len(), 3);
        assert!(!grid.is_empty());
    }

    // A point with the wrong number of coordinates is rejected.
    #[test]
    fn test_insert_wrong_dims() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        let result = grid.insert(0, &[1.0, 2.0, 3.0]);
        assert!(result.is_err());
    }

    // Removing an id succeeds once and then reports that the id is gone.
    #[test]
    fn test_remove() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.0, 0.0]).expect("insert");
        grid.insert(1, &[1.0, 1.0]).expect("insert");
        assert_eq!(grid.len(), 2);
        assert!(grid.remove(0));
        assert_eq!(grid.len(), 1);
        assert!(!grid.remove(0));
        assert_eq!(grid.len(), 1);
    }

    // `clear` empties the index.
    #[test]
    fn test_clear() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.0, 0.0]).expect("insert");
        grid.insert(1, &[1.0, 1.0]).expect("insert");
        grid.clear();
        assert_eq!(grid.len(), 0);
        assert!(grid.is_empty());
    }

    // Only points inside the radius are returned, sorted by distance.
    #[test]
    fn test_query_radius_basic() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.0, 0.0]).expect("insert");
        grid.insert(1, &[0.5, 0.5]).expect("insert");
        grid.insert(2, &[3.0, 3.0]).expect("insert");
        let (ids, dists) = grid.query_radius(&[0.0, 0.0], 1.0).expect("query");
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&0));
        assert!(ids.contains(&1));
        for i in 1..dists.len() {
            assert!(dists[i] >= dists[i - 1]);
        }
    }

    // A query far from every point returns nothing.
    #[test]
    fn test_query_radius_empty() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.0, 0.0]).expect("insert");
        let (ids, _) = grid.query_radius(&[10.0, 10.0], 0.5).expect("query");
        assert!(ids.is_empty());
    }

    // A negative radius is an error.
    #[test]
    fn test_query_radius_negative() {
        let grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        let result = grid.query_radius(&[0.0, 0.0], -1.0);
        assert!(result.is_err());
    }

    // The nearest point comes first and distances are ascending.
    #[test]
    fn test_knn_basic() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.0, 0.0]).expect("insert");
        grid.insert(1, &[1.0, 0.0]).expect("insert");
        grid.insert(2, &[0.0, 1.0]).expect("insert");
        grid.insert(3, &[10.0, 10.0]).expect("insert");
        let (ids, dists) = grid.query_knn(&[0.1, 0.1], 2).expect("knn");
        assert_eq!(ids.len(), 2);
        assert_eq!(ids[0], 0);
        assert!(dists[0] <= dists[1]);
    }

    // k = 0 yields empty results without error.
    #[test]
    fn test_knn_k_zero() {
        let grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        let (ids, dists) = grid.query_knn(&[0.0, 0.0], 0).expect("knn");
        assert!(ids.is_empty());
        assert!(dists.is_empty());
    }

    // `has_neighbor` is true near a point and false far from all points.
    #[test]
    fn test_has_neighbor() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.0, 0.0]).expect("insert");
        assert!(grid.has_neighbor(&[0.5, 0.0], 1.0).expect("check"));
        assert!(!grid.has_neighbor(&[5.0, 5.0], 1.0).expect("check"));
    }

    // Every row of the array becomes one point.
    #[test]
    fn test_from_array() {
        let pts = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let grid = GridIndex::<f64>::from_array(&pts.view(), 1.0).expect("build");
        assert_eq!(grid.len(), 4);
    }

    // `all_ids` returns every inserted id (in unspecified order).
    #[test]
    fn test_all_ids() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(10, &[0.0, 0.0]).expect("insert");
        grid.insert(20, &[1.0, 1.0]).expect("insert");
        grid.insert(30, &[2.0, 2.0]).expect("insert");
        let mut ids = grid.all_ids();
        ids.sort();
        assert_eq!(ids, vec![10, 20, 30]);
    }

    // Two points in the same cell plus one elsewhere occupy two cells.
    #[test]
    fn test_cell_count() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.1, 0.1]).expect("insert");
        grid.insert(1, &[0.2, 0.2]).expect("insert");
        grid.insert(2, &[1.5, 1.5]).expect("insert");
        assert_eq!(grid.cell_count(), 2);
    }

    // Radius queries work in three dimensions.
    #[test]
    fn test_3d_grid() {
        let mut grid = GridIndex::<f64>::new(1.0, 3).expect("create");
        grid.insert(0, &[0.0, 0.0, 0.0]).expect("insert");
        grid.insert(1, &[0.5, 0.5, 0.5]).expect("insert");
        grid.insert(2, &[5.0, 5.0, 5.0]).expect("insert");
        let (ids, _) = grid.query_radius(&[0.0, 0.0, 0.0], 1.0).expect("query");
        assert_eq!(ids.len(), 2);
    }

    // The index also works with `f32` coordinates.
    #[test]
    fn test_f32_grid() {
        let mut grid = GridIndex::<f32>::new(1.0, 2).expect("create");
        grid.insert(0, &[0.0f32, 0.0]).expect("insert");
        grid.insert(1, &[0.5f32, 0.5]).expect("insert");
        let (ids, dists) = grid.query_radius(&[0.0f32, 0.0], 1.0).expect("query");
        assert_eq!(ids.len(), 2);
        assert_relative_eq!(dists[0], 0.0f32, epsilon = 1e-6);
    }

    // Negative coordinates land in negative cells and are still found.
    // Points 0 and 1 are about 0.28 and 0.42 away from the query, both within
    // 0.5; point 2 is about 1.84 away. The exact, distance-ordered result is
    // pinned so that missing either in-range point fails the test.
    #[test]
    fn test_negative_coords() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        grid.insert(0, &[-1.0, -1.0]).expect("insert");
        grid.insert(1, &[-0.5, -0.5]).expect("insert");
        grid.insert(2, &[0.5, 0.5]).expect("insert");
        let (ids, _) = grid.query_radius(&[-0.8, -0.8], 0.5).expect("query");
        assert_eq!(ids, vec![0, 1]);
    }

    // Batch insertion adds every point.
    #[test]
    fn test_insert_batch() {
        let mut grid = GridIndex::<f64>::new(1.0, 2).expect("create");
        let points = vec![
            (0, vec![0.0, 0.0]),
            (1, vec![1.0, 1.0]),
            (2, vec![2.0, 2.0]),
        ];
        grid.insert_batch(&points).expect("batch insert");
        assert_eq!(grid.len(), 3);
    }

    // Many points spread over many small cells can be inserted and queried.
    #[test]
    fn test_large_dataset() {
        let n = 500;
        let mut grid = GridIndex::<f64>::new(0.1, 2).expect("create");
        for i in 0..n {
            let x = (i as f64) * 0.01;
            let y = (i as f64) * 0.02;
            grid.insert(i, &[x, y]).expect("insert");
        }
        assert_eq!(grid.len(), n);
        let (ids, _) = grid.query_radius(&[0.0, 0.0], 0.5).expect("query");
        assert!(!ids.is_empty());
    }

    // kNN still finds points that are many cells away from the query.
    #[test]
    fn test_knn_expanding_search() {
        let mut grid = GridIndex::<f64>::new(0.1, 2).expect("create");
        grid.insert(0, &[5.0, 5.0]).expect("insert");
        grid.insert(1, &[5.1, 5.1]).expect("insert");
        let (ids, _) = grid.query_knn(&[0.0, 0.0], 2).expect("knn");
        assert_eq!(ids.len(), 2);
    }
}