imgal 0.3.0

A fast and open-source scientific image processing and algorithm library.
Documentation
use std::cmp::Ordering;

use ndarray::{Array1, Array2, ArrayBase, ArrayView2, AsArray, Axis, Ix1, Ix2, ViewRepr};

use crate::prelude::*;

/// An immutable K-d tree for fast spatial queries over *n*-dimensional points.
///
/// The K-d tree itself does not *own* its source data but instead uses a view.
/// This design ensures that imgal's K-d trees are *immutable* once constructed
/// and are intended for lookups only. The `cloud` view (*i.e.* the
/// *n*-dimensional point cloud) holds `p` points in *D* dimensions and has
/// shape `(p, D)`: each row is one point and each column is one
/// dimension/axis of that point.
#[derive(Debug)]
pub struct KDTree<'a, T> {
    /// A view into a point cloud array with shape `(p, D)`.
    pub cloud: ArrayView2<'a, T>,
    /// The K-d tree node vector; each `Node`'s `left`/`right` fields index
    /// into this vector.
    pub nodes: Vec<Node>,
    /// The index of the root node in `nodes`, or `None` if the tree is empty.
    pub root: Option<usize>,
}

/// A node of an immutable K-d tree.
///
/// K-d trees are constructed from `Node`s. The nodes are stored in a
/// `Vec<Node>`, and the `left` and `right` fields store indices into that node
/// vector (`None` when the corresponding branch is empty). The axis the split
/// occurs on is stored in `split_axis`, and the index of the node's point in
/// the source point cloud is stored in the `point_index` field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Node {
    /// The axis this node was split on.
    pub split_axis: usize,
    /// The node's current point index into the K-d tree's associated point
    /// cloud.
    pub point_index: usize,
    /// The index into the "left" branch relative to this node.
    pub left: Option<usize>,
    /// The index into the "right" branch relative to this node.
    pub right: Option<usize>,
}

impl<'a, T> KDTree<'a, T>
where
    T: 'a + AsNumeric,
{
    /// Create a new K-d tree from an *n*-dimensional point cloud.
    ///
    /// # Description
    ///
    /// Creates a new K-d tree from an *n*-dimensional point cloud with an
    /// array shape of `(p, D)`, where `p` is the point and `D` is the
    /// dimension/axis of that point. The `KDTree` does not own the point cloud
    /// data, but instead owns a vector of `Node`s that store indices into
    /// the source point cloud. Along each split axis, points are ordered by
    /// their `f64` value using a total order, so clouds that contain NaN
    /// coordinates still produce a consistently partitioned tree.
    ///
    /// # Arguments
    ///
    /// * `cloud`: An array view into a point cloud with shape `(p, D)`.
    ///
    /// # Returns
    ///
    /// * `KDTree<'a, T>`: A K-d tree with radial searching of either point
    ///   indices or coordinates.
    pub fn build<A>(cloud: A) -> Self
    where
        A: AsArray<'a, T, Ix2>,
    {
        let cloud: ArrayBase<ViewRepr<&'a T>, Ix2> = cloud.into();
        let total_points = cloud.dim().0;
        let mut tree = Self {
            cloud,
            nodes: Vec::with_capacity(total_points),
            root: None,
        };
        let mut indices: Vec<usize> = (0..total_points).collect();
        tree.root = tree.recursive_build(&mut indices, 0);
        tree
    }

    /// Search the K-d tree for all point coordinates within a given radius.
    ///
    /// # Description
    ///
    /// Performs a radial search on the K-d tree, returning the coordinates of
    /// all points whose Euclidean distance from the `query` point is less than
    /// or equal to `radius`. A negative or NaN `radius` matches no points.
    ///
    /// # Arguments
    ///
    /// * `query`: A 1-D array-like value (anything implementing `AsArray`)
    ///   representing the query point. The query point length must match the
    ///   dimension length of the point cloud.
    /// * `radius`: The radius around the query point to search.
    ///
    /// # Returns
    ///
    /// * `Ok(Array2<T>)`: The point coordinates of all neighboring points to the
    ///   `query` within the `radius`. The returned array has shape `(p, D)`,
    ///   where `p` is the point and `D` is the dimension/axis of that point.
    /// * `Err(ImgalError)`: If `query.len() != self.cloud.dim().1`.
    pub fn search_for_coords<'b, B>(&self, query: B, radius: f64) -> Result<Array2<T>, ImgalError>
    where
        B: AsArray<'b, T, Ix1>,
        T: 'b + AsNumeric,
    {
        // `search_for_indices` validates the query length and returns the
        // `MismatchedArrayLengths` error itself, so no separate check is needed
        let (coord_indices, _) = self
            .search_for_indices(query, radius)?
            .into_raw_vec_and_offset();
        Ok(self.cloud.select(Axis(0), &coord_indices))
    }

    /// Search the K-d tree for all point indices within the given radius.
    ///
    /// # Description
    ///
    /// Performs a radial search on the K-d tree, returning the indices of all
    /// points whose Euclidean distance from the `query` point is less than or
    /// equal to `radius`. A negative or NaN `radius` matches no points.
    ///
    /// # Arguments
    ///
    /// * `query`: A 1-D array-like value (anything implementing `AsArray`)
    ///   representing the query point. The query point length must match the
    ///   dimension length of the point cloud.
    /// * `radius`: The radius around the query point to search.
    ///
    /// # Returns
    ///
    /// * `Ok(Array1<usize>)`: The point indices of all neighboring points to
    ///   the query within the `radius`.
    /// * `Err(ImgalError)`: If `query.len() != self.cloud.dim().1`.
    pub fn search_for_indices<'b, B>(
        &self,
        query: B,
        radius: f64,
    ) -> Result<Array1<usize>, ImgalError>
    where
        B: AsArray<'b, T, Ix1>,
        T: 'b + AsNumeric,
    {
        let query: ArrayBase<ViewRepr<&'b T>, Ix1> = query.into();
        let q_dims = query.len();
        let c_dims = self.cloud.dim().1;
        if q_dims != c_dims {
            return Err(ImgalError::MismatchedArrayLengths {
                a_arr_name: "query",
                a_arr_len: q_dims,
                b_arr_name: "cloud array shape",
                b_arr_len: c_dims,
            });
        }
        // no distance can be negative (or compare true against NaN); without
        // this guard, squaring a negative radius would silently make it positive
        if radius.is_nan() || radius < 0.0 {
            return Ok(Array1::from_vec(Vec::new()));
        }
        // convert the query once so the recursive search works purely in f64
        let query: Vec<f64> = query.iter().map(|q| q.to_f64()).collect();
        let radius_sq = radius * radius;
        let mut results: Vec<usize> = Vec::new();

        // begin recursive searching only if the tree is not empty
        if let Some(root) = self.root {
            self.recursive_search(root, &query, radius_sq, &mut results);
        }
        Ok(Array1::from_vec(results))
    }

    /// Recursively build the K-d tree over `indices`, returning the index of
    /// the subtree's root in `self.nodes` (`None` for an empty slice).
    ///
    /// Children are pushed before their parent, so the root of the whole tree
    /// is the last node in `self.nodes`.
    fn recursive_build(&mut self, indices: &mut [usize], depth: usize) -> Option<usize> {
        if indices.is_empty() {
            return None;
        }
        let n_dims = self.cloud.dim().1;
        // a zero-dimensional cloud has no axis to split on (and `depth % 0`
        // would panic); all of its points are equivalent, so skip partitioning
        let split_axis = if n_dims == 0 { 0 } else { depth % n_dims };
        let median = indices.len() / 2;
        if n_dims > 0 {
            let cloud = &self.cloud;
            // `f64::total_cmp` is a genuine total order (NaN included), which
            // `select_nth_unstable_by` requires; it also matches the f64
            // arithmetic used when pruning during the search
            indices.select_nth_unstable_by(median, |&a, &b| -> Ordering {
                let va = cloud[[a, split_axis]].to_f64();
                let vb = cloud[[b, split_axis]].to_f64();
                va.total_cmp(&vb)
            });
        }
        let point_index = indices[median];
        // construct the left and right sub trees
        let left = self.recursive_build(&mut indices[..median], depth + 1);
        let right = self.recursive_build(&mut indices[median + 1..], depth + 1);
        // create a new Node and return this Node's index
        let node_index = self.nodes.len();
        self.nodes
            .push(Node::new(split_axis, point_index, left, right));
        Some(node_index)
    }

    /// Recursively search the subtree rooted at `node_index`, pushing the
    /// cloud index of every point whose squared distance to `query` is at
    /// most `radius_sq` onto `results`.
    fn recursive_search(
        &self,
        node_index: usize,
        query: &[f64],
        radius_sq: f64,
        results: &mut Vec<usize>,
    ) {
        // get the current node's distance from the query point and add this
        // point if we're within the radius squared
        let node = &self.nodes[node_index];
        let point = self.cloud.row(node.point_index);
        let node_dist_sq: f64 = point
            .iter()
            .zip(query)
            .map(|(c, &q)| {
                let d = c.to_f64() - q;
                d * d
            })
            .sum();
        if node_dist_sq <= radius_sq {
            results.push(node.point_index);
        }
        // signed offset of the query from this node's splitting plane; zero
        // when the cloud has no dimensions (every point coincides with the query)
        let ax = node.split_axis;
        let diff = match query.get(ax) {
            Some(&q) => q - point[ax].to_f64(),
            None => 0.0,
        };
        // decide the traversal order and recurse into the near side and far
        // side (only if needed)
        let (near, far) = if diff <= 0.0 {
            (node.left, node.right)
        } else {
            (node.right, node.left)
        };
        if let Some(child) = near {
            self.recursive_search(child, query, radius_sq, results);
        }
        // prune the far side only when the splitting plane is provably outside
        // the radius; a NaN offset (from a NaN coordinate) proves nothing, so
        // that side must still be searched
        if (diff.is_nan() || diff * diff <= radius_sq)
            && let Some(child) = far
        {
            self.recursive_search(child, query, radius_sq, results);
        }
    }
}

impl Node {
    /// Creates a new K-d tree node.
    pub fn new(
        split_axis: usize,
        point_index: usize,
        left: Option<usize>,
        right: Option<usize>,
    ) -> Self {
        Self {
            split_axis,
            point_index,
            left,
            right,
        }
    }
}