linfa-nn 0.8.1

A collection of nearest neighbour algorithms
Documentation
use approx::assert_abs_diff_eq;
use ndarray::{arr1, arr2, aview1, stack, Array2, ArrayBase, ArrayView1, Axis, Dim, ViewRepr};
use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
use ndarray_stats::DeviationExt;
use noisy_float::{checkers::FiniteChecker, NoisyFloat};
use rand_xoshiro::Xoshiro256Plus;

use linfa_nn::{distance::*, CommonNearestNeighbour, LinearSearch, NearestNeighbour};

fn sort_by_dist<'a>(
    mut vec: Vec<(ArrayView1<'a, f64>, usize)>,
    pt: ArrayView1<f64>,
) -> Vec<(ArrayView1<'a, f64>, usize)> {
    vec.sort_by_key(|v| NoisyFloat::<_, FiniteChecker>::new(v.0.sq_l2_dist(&pt).unwrap()));
    vec
}

fn assert_query(
    output: Vec<(ArrayView1<f64>, usize)>,
    input_data: &Array2<f64>,
    exp_pos: Vec<usize>,
) {
    let (pts, pos): (Vec<_>, Vec<_>) = output.into_iter().unzip();
    assert_eq!(pos, exp_pos);
    assert_abs_diff_eq!(
        stack(Axis(0), &pts).unwrap(),
        input_data.select(Axis(0), &exp_pos)
    );
}

fn nn_test_empty(builder: &CommonNearestNeighbour) {
    let points = Array2::zeros((0, 2));
    let nn = builder.from_batch(&points, L2Dist).unwrap();

    let out = nn.k_nearest(aview1(&[0.0, 1.0]), 2).unwrap();
    assert_eq!(out, Vec::<_>::new());

    let out = nn.k_nearest(aview1(&[4.0, 4.0]), 3).unwrap();
    assert_eq!(out, Vec::<_>::new());

    let pt = aview1(&[6.0, 3.0]);
    let out = nn.within_range(pt, 9.0).unwrap();
    assert_eq!(out, Vec::<_>::new());
}

fn nn_test_error(builder: &CommonNearestNeighbour) {
    let points = Array2::<f64>::zeros((4, 0));
    assert!(builder.from_batch(&points, L2Dist).is_err());

    let points = arr2(&[[0.0, 2.0]]);
    assert!(builder
        .from_batch_with_leaf_size(&points, 0, L2Dist)
        .is_err());
    let nn = builder.from_batch(&points, L2Dist).unwrap();
    assert!(nn.k_nearest(aview1(&[]), 2).is_err());
    assert!(nn.within_range(aview1(&[2.2, 4.4, 5.5]), 4.0).is_err());
}

fn nn_test(builder: &CommonNearestNeighbour, sort_within_range: bool) {
    let points = arr2(&[[0.0, 2.0], [10.0, 4.0], [4.0, 5.0], [7.0, 1.0], [1.0, 7.2]]);
    let nn = builder.from_batch(&points, L2Dist).unwrap();

    let out = nn.k_nearest(aview1(&[0.0, 1.0]), 2).unwrap();
    assert_query(out, &points, vec![0, 2]);

    let out = nn.k_nearest(aview1(&[4.0, 4.0]), 3).unwrap();
    assert_query(out, &points, vec![2, 3, 4]);

    let out = nn.k_nearest(aview1(&[4.0, 4.0]), 10).unwrap();
    assert_query(out, &points, vec![2, 3, 4, 0, 1]);

    let pt = aview1(&[6.0, 3.0]);
    let mut out = nn.within_range(pt, 4.3).unwrap();
    if sort_within_range {
        out = sort_by_dist(out, pt);
    }
    assert_query(out, &points, vec![3, 2, 1]);
}

fn nn_test_degenerate(builder: &CommonNearestNeighbour) {
    let points = arr2(&[[0.0, 2.0], [0.0, 2.0], [0.0, 2.0], [0.0, 2.0], [0.0, 2.0]]);
    let nn = builder.from_batch(&points, L2Dist).unwrap();

    let out = nn
        .k_nearest(aview1(&[0.0, 1.0]), 2)
        .unwrap()
        .into_iter()
        .map(|(p, _)| p.reborrow())
        .collect::<Vec<_>>();
    assert_abs_diff_eq!(
        stack(Axis(0), &out).unwrap(),
        arr2(&[[0.0, 2.0], [0.0, 2.0]])
    );

    let out = nn
        .k_nearest(aview1(&[4.0, 4.0]), 3)
        .unwrap()
        .into_iter()
        .map(|(p, _)| p.reborrow())
        .collect::<Vec<_>>();
    assert_abs_diff_eq!(
        stack(Axis(0), &out).unwrap(),
        arr2(&[[0.0, 2.0], [0.0, 2.0], [0.0, 2.0]])
    );

    let pt = aview1(&[3.0, 2.0]);
    let out = nn
        .within_range(pt, 1.0)
        .unwrap()
        .into_iter()
        .map(|(p, _)| p.reborrow())
        .collect::<Vec<_>>();
    assert_eq!(
        out,
        Vec::<ArrayBase<ViewRepr<&f64>, Dim<[usize; 1]>>>::new()
    );
    let out = nn
        .within_range(pt, 20.0)
        .unwrap()
        .into_iter()
        .map(|(p, _)| p.reborrow())
        .collect::<Vec<_>>();
    assert_abs_diff_eq!(stack(Axis(0), &out).unwrap(), points);
}

fn assert_eq_queries(out1: Vec<(ArrayView1<f64>, usize)>, out2: Vec<(ArrayView1<f64>, usize)>) {
    let (pts1, pos1): (Vec<_>, Vec<_>) = out1.into_iter().unzip();
    let (pts2, pos2): (Vec<_>, Vec<_>) = out2.into_iter().unzip();
    assert_eq!(pos1, pos2);
    assert_abs_diff_eq!(
        stack(Axis(0), &pts1).unwrap(),
        stack(Axis(0), &pts2).unwrap(),
    );
}

fn nn_test_random<D: 'static + Distance<f64> + Clone>(
    builder: &CommonNearestNeighbour,
    dist_fn: D,
) {
    let mut rng = Xoshiro256Plus::seed_from_u64(40);
    let n_points = 50000;
    let n_features = 3;
    let points = Array2::random_using((n_points, n_features), Uniform::new(-50., 50.), &mut rng);

    let linear = LinearSearch::new()
        .from_batch(&points, dist_fn.clone())
        .unwrap();

    let nn = builder.from_batch(&points, dist_fn).unwrap();

    let pt = arr1(&[0., 0., 0.]);
    assert_eq_queries(
        nn.k_nearest(pt.view(), 5).unwrap(),
        linear.k_nearest(pt.view(), 5).unwrap(),
    );
    assert_eq_queries(
        sort_by_dist(nn.within_range(pt.view(), 15.0).unwrap(), pt.view()),
        sort_by_dist(linear.within_range(pt.view(), 15.0).unwrap(), pt.view()),
    );

    let pt = arr1(&[-3.4, 10., 0.95]);
    assert_eq_queries(
        nn.k_nearest(pt.view(), 30).unwrap(),
        linear.k_nearest(pt.view(), 30).unwrap(),
    );
    assert_eq_queries(
        sort_by_dist(nn.within_range(pt.view(), 25.0).unwrap(), pt.view()),
        sort_by_dist(linear.within_range(pt.view(), 25.0).unwrap(), pt.view()),
    );
}

macro_rules! nn_tests {
    ($mod:ident, $builder:ident, $sort:expr $(, $_u:ident)?) => {
        mod $mod {
            use super::*;

            #[test]
            fn empty() {
                nn_test_empty(&CommonNearestNeighbour::$builder);
            }

            #[test]
            fn error() {
                nn_test_error(&CommonNearestNeighbour::$builder);
            }

            #[test]
            fn normal() {
                nn_test(&CommonNearestNeighbour::$builder, $sort);
            }

            #[test]
            fn degenerate() {
                nn_test_degenerate(&CommonNearestNeighbour::$builder);
            }

            $(
                #[test]
                fn random_l2() {
                    let $_u: () = ();
                    nn_test_random(&CommonNearestNeighbour::$builder, L2Dist);
                }

                #[test]
                fn random_l1() {
                    nn_test_random(&CommonNearestNeighbour::$builder, L1Dist);
                }
            )?
        }
    };
}

nn_tests!(linear_search, LinearSearch, true);
nn_tests!(kdtree, KdTree, false, _u);
nn_tests!(balltree, BallTree, false, _u);