use std::collections::HashMap;
/// Integer coordinates of one grid cell, one component per dimension.
///
/// Used as the `HashMap` key that groups point indices by cell.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CellKey(Vec<i32>);
/// A uniform-grid spatial hash that maps integer cell coordinates to the
/// indices of the points stored in each cell.
///
/// Neighbour queries inspect the cell containing the query point together
/// with every cell adjacent to it.
pub struct SpatialHash {
    /// Number of components every inserted and queried point is expected to have.
    ndim: usize,
    /// Edge length of each axis-aligned cell, as passed to [`SpatialHash::new`].
    // Cell lookups only use `inv_cell_size`, so this field may otherwise go
    // unread; it is retained for reference, hence the narrow lint allowance.
    #[allow(dead_code)]
    cell_size: f64,
    /// Cached `1.0 / cell_size`, so mapping a point to its cell multiplies
    /// instead of divides.
    inv_cell_size: f64,
    /// Only cells that have received at least one insertion have an entry;
    /// each entry holds the point indices inserted into that cell.
    cells: HashMap<CellKey, Vec<usize>>,
}
impl SpatialHash {
    /// Creates an empty spatial hash for `ndim`-dimensional points, bucketed
    /// into axis-aligned cells with edge length `cell_size`.
    ///
    /// A warning is printed to stderr when `ndim > 3`, because every neighbour
    /// query visits `3^ndim` cells.
    ///
    /// # Panics
    ///
    /// Panics if `cell_size` is not strictly positive or its reciprocal is not
    /// finite (zero, negative, NaN or subnormal sizes). Such values would turn
    /// every coordinate into an infinite or NaN cell index and make queries
    /// meaningless.
    pub fn new(ndim: usize, cell_size: f64) -> Self {
        let inv_cell_size = 1.0 / cell_size;
        assert!(
            cell_size > 0.0 && inv_cell_size.is_finite(),
            "SpatialHash cell_size must be positive with a finite reciprocal, got {}",
            cell_size
        );
        if ndim > 3 {
            eprintln!(
                "Warning: SpatialHash in {ndim}D may be inefficient!"
            );
        }
        Self {
            ndim,
            cell_size,
            inv_cell_size,
            cells: HashMap::new(),
        }
    }

    /// Returns the number of dimensions this hash was created for.
    pub fn ndim(&self) -> usize {
        self.ndim
    }

    /// Returns the cell edge length this hash was created with.
    pub fn cell_size(&self) -> f64 {
        self.cell_size
    }

    /// Maps a coordinate vector to the integer cell that contains it.
    ///
    /// Each component is scaled by `1 / cell_size` and floored, so negative
    /// coordinates round toward negative infinity. The `as i32` cast saturates:
    /// components beyond the `i32` range clamp to `i32::MIN` / `i32::MAX`, and
    /// NaN maps to 0.
    fn cell_of(&self, x: &[f64]) -> CellKey {
        CellKey(
            x.iter()
                .map(|&coord| (coord * self.inv_cell_size).floor() as i32)
                .collect(),
        )
    }

    /// Records point index `idx` at position `x`.
    ///
    /// Only the index is stored, not the coordinates, and duplicate indices
    /// are not detected. `x` must have `ndim` components; this is checked only
    /// in debug builds.
    pub fn insert(&mut self, idx: usize, x: &[f64]) {
        debug_assert_eq!(x.len(), self.ndim);
        let key = self.cell_of(x);
        self.cells.entry(key).or_default().push(idx);
    }

    /// Builds a hash containing every point in `points`, using each point's
    /// position in the slice as its index.
    ///
    /// The dimension is taken from the first point.
    ///
    /// # Panics
    ///
    /// Panics if `points` is empty, because the dimension cannot be inferred.
    /// Also panics if [`SpatialHash::new`] rejects `cell_size`.
    pub fn build(points: &[Vec<f64>], cell_size: f64) -> Self {
        let ndim = points
            .first()
            .expect("SpatialHash::build needs at least one point to infer the dimension")
            .len();
        let mut hash = Self::new(ndim, cell_size);
        for (i, p) in points.iter().enumerate() {
            hash.insert(i, p);
        }
        hash
    }

    /// Calls `f` once for every index stored in the cell containing `x` or in
    /// any cell adjacent to it. This is the `3^ndim` block of cells centred on
    /// `x`'s cell, diagonals included.
    ///
    /// Any stored point whose distance from `x` is strictly less than
    /// `cell_size` along every axis falls in that block (up to floating-point
    /// rounding at cell boundaries), so it is reported. A point stored at `x`
    /// itself is reported too. Farther points from the outer cells may also be
    /// reported, because no distance test is performed; callers that need an
    /// exact radius must filter the indices themselves.
    ///
    /// Adjacent cells whose coordinates would fall outside the `i32` range
    /// cannot hold any points, because [`Self::cell_of`] saturates. They are
    /// skipped.
    pub fn for_each_neighbor<F: FnMut(usize)>(&self, x: &[f64], mut f: F) {
        debug_assert_eq!(x.len(), self.ndim);
        let base = self.cell_of(x);
        // A single scratch key, rewritten in place for each visited cell, so
        // one query allocates once instead of once per cell.
        let mut probe = base.clone();
        self.visit_cells(0, &base, &mut probe, &mut f);
    }

    /// Sets axis `depth` of `probe` to each of `base[depth] - 1`,
    /// `base[depth]` and `base[depth] + 1` in turn, recursing over the
    /// remaining axes. Once every axis is set, it reports the contents of the
    /// cell `probe` names.
    fn visit_cells<F: FnMut(usize)>(
        &self,
        depth: usize,
        base: &CellKey,
        probe: &mut CellKey,
        f: &mut F,
    ) {
        if depth == self.ndim {
            if let Some(bucket) = self.cells.get(&*probe) {
                for &idx in bucket {
                    f(idx);
                }
            }
            return;
        }
        for delta in [-1, 0, 1] {
            // Overflow means the neighbouring cell lies past i32::MIN/MAX,
            // where no point can have been stored. Skip it instead of
            // panicking (debug) or wrapping to the opposite end (release).
            if let Some(coord) = base.0[depth].checked_add(delta) {
                probe.0[depth] = coord;
                self.visit_cells(depth + 1, base, probe, f);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::SpatialHash;

    /// Runs a neighbour query and returns the reported indices in ascending
    /// order, so assertions do not depend on the order cells are visited.
    fn collect_neighbors(hash: &SpatialHash, x: &[f64]) -> Vec<usize> {
        let mut out = Vec::new();
        hash.for_each_neighbor(x, |i| out.push(i));
        out.sort_unstable();
        out
    }

    #[test]
    fn inserts_and_finds_self() {
        let points = vec![
            vec![0.1, 0.2],
            vec![1.5, 2.5],
            vec![9.9, -3.2],
        ];
        let hash = SpatialHash::build(&points, 1.0);
        for (i, p) in points.iter().enumerate() {
            let neighbors = collect_neighbors(&hash, p);
            assert!(
                neighbors.contains(&i),
                "Point {i} not found in its own neighbor list"
            );
        }
    }

    #[test]
    fn nearby_points_are_found() {
        let points = vec![
            vec![0.1, 0.1],
            vec![0.9, 0.8],
            vec![1.1, 1.0],
            vec![5.0, 5.0],
        ];
        let hash = SpatialHash::build(&points, 1.0);
        let n0 = collect_neighbors(&hash, &points[0]);
        assert!(n0.contains(&1));
        assert!(n0.contains(&2));
        assert!(!n0.contains(&3));
        let n1 = collect_neighbors(&hash, &points[1]);
        assert!(n1.contains(&0));
        assert!(n1.contains(&2));
        assert!(!n1.contains(&3));
        let n2 = collect_neighbors(&hash, &points[2]);
        assert!(n2.contains(&1));
        assert!(n2.contains(&0));
        assert!(!n2.contains(&3));
        let n3 = collect_neighbors(&hash, &points[3]);
        assert!(!n3.contains(&0));
        assert!(!n3.contains(&1));
        assert!(!n3.contains(&2));
    }

    #[test]
    fn works_in_1d() {
        let points = vec![
            vec![0.1],
            vec![0.9],
            vec![1.1],
            vec![5.0],
        ];
        let hash = SpatialHash::build(&points, 1.0);
        let n0 = collect_neighbors(&hash, &points[0]);
        assert!(n0.contains(&1));
        assert!(n0.contains(&2));
        assert!(!n0.contains(&3));
    }

    #[test]
    fn works_in_3d() {
        let points = vec![
            vec![0.1; 3],
            vec![0.9; 3],
            vec![1.1, 1.0, 1.0],
            vec![5.0, 5.0, 5.0],
        ];
        let hash = SpatialHash::build(&points, 1.0);
        let n0 = collect_neighbors(&hash, &points[0]);
        assert!(n0.contains(&1));
        assert!(n0.contains(&2));
        assert!(!n0.contains(&3));
    }

    #[test]
    fn works_for_irregular_layouts() {
        let points = vec![
            vec![0., 0.],
            vec![0.5, 0.866],
            vec![1., 0.],
            vec![1.5, 0.866],
            vec![5.0, 5.0],
        ];
        let hash = SpatialHash::build(&points, 1.0);
        let n0 = collect_neighbors(&hash, &points[0]);
        assert!(n0.contains(&1));
        assert!(n0.contains(&2));
        assert!(n0.contains(&3));
        assert!(!n0.contains(&4));
    }

    #[test]
    fn negative_coordinates_use_floor() {
        // -0.5 lies in cell -1 (floor), not cell 0 (truncation), so 1.2 in
        // cell 1 is two cells away and must not be reported.
        let points = vec![vec![-0.5], vec![0.5], vec![1.2]];
        let hash = SpatialHash::build(&points, 1.0);
        assert_eq!(collect_neighbors(&hash, &points[0]), vec![0, 1]);
    }

    #[test]
    fn points_two_cells_apart_are_not_neighbors() {
        let points = vec![vec![0.5], vec![2.5]];
        let hash = SpatialHash::build(&points, 1.0);
        assert_eq!(collect_neighbors(&hash, &points[0]), vec![0]);
        assert_eq!(collect_neighbors(&hash, &points[1]), vec![1]);
    }

    #[test]
    fn empty_hash_reports_nothing() {
        let hash = SpatialHash::new(2, 1.0);
        assert!(collect_neighbors(&hash, &[0.0, 0.0]).is_empty());
    }

    #[test]
    fn insertion_order_does_not_matter() {
        let points = vec![
            vec![0.1, 0.1],
            vec![0.9, 0.8],
            vec![1.1, 1.0],
        ];
        let mut reversed = points.clone();
        reversed.reverse();
        let hash1 = SpatialHash::build(&points, 1.0);
        let hash2 = SpatialHash::build(&reversed, 1.0);
        let last = points.len() - 1;
        for p in &points {
            let n1 = collect_neighbors(&hash1, p);
            // `reversed[j]` is `points[last - j]`; translate hash2's indices
            // back into `points` numbering so the actual sets are compared,
            // not just their sizes.
            let mut n2: Vec<usize> = collect_neighbors(&hash2, p)
                .into_iter()
                .map(|j| last - j)
                .collect();
            n2.sort_unstable();
            assert_eq!(n1, n2);
        }
    }
}