kd-tree 0.6.2

k-dimensional tree
Documentation
use crate::{ItemAndDistance, KdPoint};

pub fn kd_nearest<'a, T: KdPoint>(
    kdtree: &'a [T],
    query: &impl KdPoint<Scalar = T::Scalar, Dim = T::Dim>,
) -> ItemAndDistance<'a, T, T::Scalar> {
    kd_nearest_by(kdtree, query, |item, k| item.at(k))
}

pub fn kd_nearest_by<'a, T, P: KdPoint>(
    kdtree: &'a [T],
    query: &P,
    get: impl Fn(&T, usize) -> P::Scalar + Copy,
) -> ItemAndDistance<'a, T, P::Scalar> {
    fn distance_squared<P: KdPoint, T>(
        p1: &P,
        p2: &T,
        get: impl Fn(&T, usize) -> P::Scalar,
    ) -> P::Scalar {
        let mut squared_distance = <P::Scalar as num_traits::Zero>::zero();
        for i in 0..P::dim() {
            let diff = p1.at(i) - get(p2, i);
            squared_distance += diff * diff;
        }
        squared_distance
    }
    fn recurse<'a, T, Q: KdPoint>(
        nearest: &mut ItemAndDistance<'a, T, Q::Scalar>,
        kdtree: &'a [T],
        get: impl Fn(&T, usize) -> Q::Scalar + Copy,
        query: &Q,
        axis: usize,
    ) {
        let mid_idx = kdtree.len() / 2;
        let item = &kdtree[mid_idx];
        let squared_distance = distance_squared(query, item, get);
        if squared_distance < nearest.squared_distance {
            nearest.item = item;
            nearest.squared_distance = squared_distance;
            use num_traits::Zero;
            if nearest.squared_distance.is_zero() {
                return;
            }
        }
        let mid_pos = get(item, axis);
        let [branch1, branch2] = if query.at(axis) < mid_pos {
            [&kdtree[..mid_idx], &kdtree[mid_idx + 1..]]
        } else {
            [&kdtree[mid_idx + 1..], &kdtree[..mid_idx]]
        };
        if !branch1.is_empty() {
            recurse(nearest, branch1, get, query, (axis + 1) % Q::dim());
        }
        if !branch2.is_empty() {
            let diff = query.at(axis) - mid_pos;
            if diff * diff < nearest.squared_distance {
                recurse(nearest, branch2, get, query, (axis + 1) % Q::dim());
            }
        }
    }
    assert!(!kdtree.is_empty());
    let mut nearest = ItemAndDistance {
        item: &kdtree[0],
        squared_distance: distance_squared(query, &kdtree[0], get),
    };
    recurse(&mut nearest, kdtree, get, query, 0);
    nearest
}

#[allow(dead_code)]
pub fn kd_nearest_with<T, Scalar>(
    kdtree: &[T],
    dim: usize,
    kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
) -> ItemAndDistance<'_, T, Scalar>
where
    Scalar: num_traits::NumAssign + Copy + PartialOrd,
{
    fn squared_distance<T, Scalar: num_traits::NumAssign + Copy>(
        item: &T,
        dim: usize,
        kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
    ) -> Scalar {
        let mut squared_distance = Scalar::zero();
        for k in 0..dim {
            let diff = kd_difference(item, k);
            squared_distance += diff * diff;
        }
        squared_distance
    }
    fn recurse<'a, T, Scalar>(
        nearest: &mut ItemAndDistance<'a, T, Scalar>,
        kdtree: &'a [T],
        axis: usize,
        dim: usize,
        kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
    ) where
        Scalar: num_traits::NumAssign + Copy + PartialOrd,
    {
        let mid_idx = kdtree.len() / 2;
        let mid = &kdtree[mid_idx];
        let squared_distance = squared_distance(mid, dim, kd_difference);
        if squared_distance < nearest.squared_distance {
            *nearest = ItemAndDistance {
                item: mid,
                squared_distance,
            };
            if nearest.squared_distance.is_zero() {
                return;
            }
        }
        let [branch1, branch2] = if kd_difference(mid, axis) < Scalar::zero() {
            [&kdtree[..mid_idx], &kdtree[mid_idx + 1..]]
        } else {
            [&kdtree[mid_idx + 1..], &kdtree[..mid_idx]]
        };
        if !branch1.is_empty() {
            recurse(nearest, branch1, (axis + 1) % dim, dim, kd_difference);
        }
        if !branch2.is_empty() {
            let diff = kd_difference(mid, axis);
            if diff * diff < nearest.squared_distance {
                recurse(nearest, branch2, (axis + 1) % dim, dim, kd_difference);
            }
        }
    }
    assert!(!kdtree.is_empty());
    let mut nearest = ItemAndDistance {
        item: &kdtree[0],
        squared_distance: squared_distance(&kdtree[0], dim, kd_difference),
    };
    recurse(&mut nearest, kdtree, 0, dim, kd_difference);
    nearest
}