csrk 1.1.4

Sparse Gaussian Process regression with compactly supported radial kernels
Documentation
//! Spatial Hash struct for sparse matrix operations

// Import hashmap
use std::collections::HashMap;
// Define a CellKey struct which is secretly a vector
/// A hash key type for a vector of i32 values.
///
/// Holds one integer cell coordinate per spatial dimension.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CellKey(Vec<i32>);

/// The struct containing the hash of cell indices for some points.
///
/// Every coordinate of a point is scaled by `1 / cell_size` and floored to
/// obtain the integer coordinates of the cell the point falls in; each
/// occupied cell maps to the indices of the points stored in it.
#[derive(Clone, Debug)]
pub struct SpatialHash {
    /// The dimensionality of our space
    ndim: usize,
    /// The size of a cell
    // Only the reciprocal below is read by the methods of this type, so the
    // field would otherwise be reported as unused.
    #[allow(unused)]
    cell_size: f64,
    /// 1/cell size tends to be computed a lot
    inv_cell_size: f64,
    /// The hash map to go from cell indices (hashed) back to indices
    /// of the list of points hashed at build.
    cells: HashMap<CellKey, Vec<usize>>,
}
impl SpatialHash {
    /// Constructor for an empty SpatialHash.
    ///
    /// * `ndim` - number of coordinates every inserted/queried point has.
    /// * `cell_size` - edge length of a grid cell. It is not validated here;
    ///   NOTE(review): presumably callers pass a positive, finite value.
    ///
    /// Prints a warning to stderr when `ndim > 3`, because a neighbor query
    /// visits `3^ndim` cells.
    pub fn new(ndim: usize, cell_size: f64) -> Self {
        // Warn user about dimensions
        if ndim > 3 {
            eprintln!(
                "Warning: SpatialHash in {ndim}D may be inefficient!"
            );
        }
        // Return new struct
        Self {
            ndim,
            cell_size,
            inv_cell_size: 1.0 / cell_size,
            cells: HashMap::new(),
        }
    }
    /// Return a CellKey corresponding to
    /// a given point's spatial coordinates.
    ///
    /// The float-to-`i32` conversion uses an `as` cast, which saturates:
    /// coordinates outside the `i32` range collapse into the outermost
    /// cells and NaN maps to cell 0.
    fn cell_of(&self, x: &[f64]) -> CellKey {
        // One integer cell coordinate per spatial coordinate
        CellKey(
            x.iter()
                .map(|&coord| (coord * self.inv_cell_size).floor() as i32)
                .collect(),
        )
    }
    /// Method for assigning the index of a point to the internal HashTable.
    ///
    /// `idx` is appended to the bucket of the cell containing `x`. The length
    /// of `x` is checked against the dimensionality in debug builds only.
    pub fn insert(&mut self, idx: usize, x: &[f64]) {
        // check length of x
        debug_assert_eq!(x.len(), self.ndim);
        let key = self.cell_of(x);
        // assign the value idx to the table with key
        self.cells.entry(key).or_default().push(idx);
    }
    /// Populate a SpatialHash map with points.
    ///
    /// The dimensionality is taken from the first point; every point is
    /// stored under its position in `points`. An empty slice yields an empty
    /// zero-dimensional hash instead of panicking.
    pub fn build(points: &[Vec<f64>], cell_size: f64) -> Self {
        // Identify dimensionality of data (0 when there is no data at all)
        let ndim = points.first().map_or(0, Vec::len);
        // Define a mutable SpatialHash
        let mut hash = Self::new(ndim, cell_size);
        // Loop through points and hash them
        for (i, p) in points.iter().enumerate() {
            hash.insert(i, p);
        }
        // return new hash struct
        hash
    }
    /// Loop the neighboring cells of point x with a function
    ///   evaluated on each point in those cells.
    ///
    /// `f` receives the index of every point stored in the cell containing
    /// `x` and in the cells offset by -1, 0 or +1 along each axis. No
    /// distance test is applied here. Cells are visited with the first axis
    /// varying slowest, each axis in the order -1, 0, +1.
    pub fn for_each_neighbor<F: FnMut(usize)>(&self, x: &[f64], mut f: F)
    {
        // check length of x
        debug_assert_eq!(x.len(), self.ndim);
        // Identify the cell this point is in
        let base = self.cell_of(x);
        // Scratch key rewritten in place for every visited cell, so the
        //   whole walk needs a single allocation
        let mut probe = base.clone();
        // Recursively execute f on points in adjacent cells
        //   by looping a 3^n cell frame
        self.recurse_offsets(0, &base, &mut probe, &mut f);
    }
    /// Execute f in cell frame.
    ///
    /// Axes `0..depth` of `probe` are already fixed; this call fixes axis
    /// `depth` to each neighbor of `base` in turn and recurses. Once every
    /// axis is fixed, `f` is called for each index stored under `probe`.
    fn recurse_offsets<F: FnMut(usize)>(
        &self,
        depth: usize,
        base: &CellKey,
        probe: &mut CellKey,
        f: &mut F,
    )
    {
        // All axes fixed: probe names one concrete cell of the frame
        if depth == base.0.len() {
            // Find all the points in that cell and execute f on them
            if let Some(bucket) = self.cells.get(&*probe) {
                for &idx in bucket {
                    f(idx);
                }
            }
            return;
        }
        // Recurse through cell offsets in this dimension
        for d in [-1i32, 0, 1] {
            // A neighbor beyond the i32 range cannot be a stored key, so it
            //   is skipped rather than overflowing
            if let Some(coord) = base.0[depth].checked_add(d) {
                probe.0[depth] = coord;
                self.recurse_offsets(depth + 1, base, probe, f);
            }
        }
    }
}


#[cfg(test)]
mod tests {
    use super::SpatialHash;

    /// Gather every index reported by `for_each_neighbor` for `x`,
    /// sorted ascending so results can be compared deterministically.
    fn collect_neighbors(hash: &SpatialHash, x: &[f64]) -> Vec<usize> {
        let mut out = Vec::new();
        hash.for_each_neighbor(x, |i| out.push(i));
        out.sort_unstable();
        out
    }

    /// Every point must show up when querying at its own location.
    #[test]
    fn inserts_and_finds_self() {
        // Define some points
        let points = vec![
            vec![0.1, 0.2],
            vec![1.5, 2.5],
            vec![9.9, -3.2],
        ];
        // Instantiate a hash
        let hash = SpatialHash::build(&points, 1.0);
        // Loop the points
        for (i, p) in points.iter().enumerate() {
            let neighbors = collect_neighbors(&hash, p);
            assert!(
                neighbors.contains(&i),
                "Point {i} not found in its own neighbor list"
            );
        }
    }

    /// Points in the same or an adjacent cell are reported; distant ones
    /// are not.
    #[test]
    fn nearby_points_are_found() {
        // Define some points
        let points = vec![
            vec![0.1, 0.1],
            vec![0.9, 0.8],
            vec![1.1, 1.0], // adjacent cell
            vec![5.0, 5.0], // far away
        ];
        // build the hash
        let hash = SpatialHash::build(&points, 1.0);
        // Find neighbors for point at place zero
        let n0 = collect_neighbors(&hash, &points[0]);
        assert!(n0.contains(&1));
        assert!(n0.contains(&2));
        assert!(!n0.contains(&3));
        // Find neighbors for point at place one
        let n1 = collect_neighbors(&hash, &points[1]);
        assert!(n1.contains(&0));
        assert!(n1.contains(&2));
        assert!(!n1.contains(&3));
        // Find neighbors for point at place two
        let n2 = collect_neighbors(&hash, &points[2]);
        assert!(n2.contains(&1));
        assert!(n2.contains(&0));
        assert!(!n2.contains(&3));
        // Find neighbors for point at place three
        let n3 = collect_neighbors(&hash, &points[3]);
        assert!(!n3.contains(&0));
        assert!(!n3.contains(&1));
        assert!(!n3.contains(&2));
    }

    /// Same neighbor logic on a line.
    #[test]
    fn works_in_1d() {
        let points = vec![
            vec![0.1],
            vec![0.9],
            vec![1.1],
            vec![5.0],
        ];
        // build the hash
        let hash = SpatialHash::build(&points, 1.0);
        // Find neighbors for point at place zero
        let n0 = collect_neighbors(&hash, &points[0]);
        assert!(n0.contains(&1));
        assert!(n0.contains(&2));
        assert!(!n0.contains(&3));
    }

    /// Same neighbor logic in a volume.
    #[test]
    fn works_in_3d() {
        let points = vec![
            vec![0.1; 3],
            vec![0.9; 3],
            vec![1.1, 1.0, 1.0],
            vec![5.0, 5.0, 5.0],
        ];
        // build the hash
        let hash = SpatialHash::build(&points, 1.0);
        // Find neighbors for point at place zero
        let n0 = collect_neighbors(&hash, &points[0]);
        assert!(n0.contains(&1));
        assert!(n0.contains(&2));
        assert!(!n0.contains(&3));
    }

    /// Points that do not sit on a regular lattice are still found.
    #[test]
    fn works_for_irregular_layouts() {
        // roughly hexagonal / skewed layout
        let points = vec![
            vec![0., 0.],
            vec![0.5, 0.866],
            vec![1., 0.],
            vec![1.5, 0.866],
            vec![5.0, 5.0],
        ];
        // build the hash
        let hash = SpatialHash::build(&points, 1.0);
        // Find neighbors for point at place zero
        let n0 = collect_neighbors(&hash, &points[0]);
        assert!(n0.contains(&1));
        assert!(n0.contains(&2));
        assert!(n0.contains(&3));
        assert!(!n0.contains(&4));
    }

    /// Building from the same points in reverse order must report the same
    /// set of neighboring points, not merely the same number of them.
    #[test]
    fn insertion_order_does_not_matter() {
        // Some points; the last one is far away so that a wrong result
        // cannot hide behind "everything is a neighbor"
        let points = vec![
            vec![0.1, 0.1],
            vec![0.9, 0.8],
            vec![1.1, 1.0],
            vec![5.0, 5.0],
        ];
        // define reverse vector
        let mut reversed = points.clone();
        reversed.reverse();
        // build the hashes
        let hash1 = SpatialHash::build(&points, 1.0);
        let hash2 = SpatialHash::build(&reversed, 1.0);
        // find neighbors of the same location in both hashes
        let n1 = collect_neighbors(&hash1, &points[0]);
        let n2 = collect_neighbors(&hash2, &points[0]);
        assert_eq!(n1.len(), n2.len());
        // Index i in the reversed list is index (len - 1 - i) in the
        // original list; translate before comparing the sets
        let last = points.len() - 1;
        let mut n2_as_original: Vec<usize> =
            n2.iter().map(|&i| last - i).collect();
        n2_as_original.sort_unstable();
        assert_eq!(n1, n2_as_original);
    }
}