kd-tree 0.6.2

k-dimensional tree
Documentation
use crate::{ItemAndDistance, KdPoint};

pub fn kd_nearests<'a, T: KdPoint>(
    kdtree: &'a [T],
    query: &impl KdPoint<Scalar = T::Scalar, Dim = T::Dim>,
    num: usize,
) -> Vec<ItemAndDistance<'a, T, T::Scalar>> {
    kd_nearests_by(kdtree, query, num, |item, k| item.at(k))
}

pub fn kd_nearests_by<'a, T, P: KdPoint>(
    kdtree: &'a [T],
    query: &P,
    num: usize,
    get: impl Fn(&T, usize) -> P::Scalar + Copy,
) -> Vec<ItemAndDistance<'a, T, P::Scalar>> {
    fn distance_squared<P: KdPoint, T>(
        p1: &P,
        p2: &T,
        get: impl Fn(&T, usize) -> P::Scalar,
    ) -> P::Scalar {
        let mut squared_distance = <P::Scalar as num_traits::Zero>::zero();
        for i in 0..P::dim() {
            let diff = p1.at(i) - get(p2, i);
            squared_distance += diff * diff;
        }
        squared_distance
    }
    fn recurse<'a, T, Q: KdPoint>(
        nearests: &mut Vec<ItemAndDistance<'a, T, Q::Scalar>>,
        kdtree: &'a [T],
        get: impl Fn(&T, usize) -> Q::Scalar + Copy,
        query: &Q,
        axis: usize,
    ) {
        let mid_idx = kdtree.len() / 2;
        let item = &kdtree[mid_idx];
        let squared_distance = distance_squared(query, item, get);
        if nearests.len() < nearests.capacity()
            || squared_distance < nearests.last().unwrap().squared_distance
        {
            if nearests.len() == nearests.capacity() {
                nearests.pop();
            }
            let i = nearests
                .binary_search_by(|item| {
                    item.squared_distance
                        .partial_cmp(&squared_distance)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .unwrap_or_else(|i| i);
            nearests.insert(
                i,
                ItemAndDistance {
                    item,
                    squared_distance,
                },
            );
        }
        let mid_pos = get(item, axis);
        let [branch1, branch2] = if query.at(axis) < mid_pos {
            [&kdtree[..mid_idx], &kdtree[mid_idx + 1..]]
        } else {
            [&kdtree[mid_idx + 1..], &kdtree[..mid_idx]]
        };
        if !branch1.is_empty() {
            recurse(nearests, branch1, get, query, (axis + 1) % Q::dim());
        }
        if !branch2.is_empty() {
            let diff = query.at(axis) - mid_pos;
            if diff * diff < nearests.last().unwrap().squared_distance {
                recurse(nearests, branch2, get, query, (axis + 1) % Q::dim());
            }
        }
    }
    if num == 0 || kdtree.is_empty() {
        return Vec::new();
    }
    let mut nearests = Vec::with_capacity(num);
    recurse(&mut nearests, kdtree, get, query, 0);
    nearests
}