kd_tree/
nearest.rs

1use crate::{ItemAndDistance, KdPoint};
2
3pub fn kd_nearest<'a, T: KdPoint>(
4    kdtree: &'a [T],
5    query: &impl KdPoint<Scalar = T::Scalar, Dim = T::Dim>,
6) -> ItemAndDistance<'a, T, T::Scalar> {
7    kd_nearest_by(kdtree, query, |item, k| item.at(k))
8}
9
10pub fn kd_nearest_by<'a, T, P: KdPoint>(
11    kdtree: &'a [T],
12    query: &P,
13    get: impl Fn(&T, usize) -> P::Scalar + Copy,
14) -> ItemAndDistance<'a, T, P::Scalar> {
15    fn distance_squared<P: KdPoint, T>(
16        p1: &P,
17        p2: &T,
18        get: impl Fn(&T, usize) -> P::Scalar,
19    ) -> P::Scalar {
20        let mut squared_distance = <P::Scalar as num_traits::Zero>::zero();
21        for i in 0..P::dim() {
22            let diff = p1.at(i) - get(p2, i);
23            squared_distance += diff * diff;
24        }
25        squared_distance
26    }
27    fn recurse<'a, T, Q: KdPoint>(
28        nearest: &mut ItemAndDistance<'a, T, Q::Scalar>,
29        kdtree: &'a [T],
30        get: impl Fn(&T, usize) -> Q::Scalar + Copy,
31        query: &Q,
32        axis: usize,
33    ) {
34        let mid_idx = kdtree.len() / 2;
35        let item = &kdtree[mid_idx];
36        let squared_distance = distance_squared(query, item, get);
37        if squared_distance < nearest.squared_distance {
38            nearest.item = item;
39            nearest.squared_distance = squared_distance;
40            use num_traits::Zero;
41            if nearest.squared_distance.is_zero() {
42                return;
43            }
44        }
45        let mid_pos = get(item, axis);
46        let [branch1, branch2] = if query.at(axis) < mid_pos {
47            [&kdtree[..mid_idx], &kdtree[mid_idx + 1..]]
48        } else {
49            [&kdtree[mid_idx + 1..], &kdtree[..mid_idx]]
50        };
51        if !branch1.is_empty() {
52            recurse(nearest, branch1, get, query, (axis + 1) % Q::dim());
53        }
54        if !branch2.is_empty() {
55            let diff = query.at(axis) - mid_pos;
56            if diff * diff < nearest.squared_distance {
57                recurse(nearest, branch2, get, query, (axis + 1) % Q::dim());
58            }
59        }
60    }
61    assert!(!kdtree.is_empty());
62    let mut nearest = ItemAndDistance {
63        item: &kdtree[0],
64        squared_distance: distance_squared(query, &kdtree[0], get),
65    };
66    recurse(&mut nearest, kdtree, get, query, 0);
67    nearest
68}
69
70#[allow(dead_code)]
71pub fn kd_nearest_with<T, Scalar>(
72    kdtree: &[T],
73    dim: usize,
74    kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
75) -> ItemAndDistance<'_, T, Scalar>
76where
77    Scalar: num_traits::NumAssign + Copy + PartialOrd,
78{
79    fn squared_distance<T, Scalar: num_traits::NumAssign + Copy>(
80        item: &T,
81        dim: usize,
82        kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
83    ) -> Scalar {
84        let mut squared_distance = Scalar::zero();
85        for k in 0..dim {
86            let diff = kd_difference(item, k);
87            squared_distance += diff * diff;
88        }
89        squared_distance
90    }
91    fn recurse<'a, T, Scalar>(
92        nearest: &mut ItemAndDistance<'a, T, Scalar>,
93        kdtree: &'a [T],
94        axis: usize,
95        dim: usize,
96        kd_difference: impl Fn(&T, usize) -> Scalar + Copy,
97    ) where
98        Scalar: num_traits::NumAssign + Copy + PartialOrd,
99    {
100        let mid_idx = kdtree.len() / 2;
101        let mid = &kdtree[mid_idx];
102        let squared_distance = squared_distance(mid, dim, kd_difference);
103        if squared_distance < nearest.squared_distance {
104            *nearest = ItemAndDistance {
105                item: mid,
106                squared_distance,
107            };
108            if nearest.squared_distance.is_zero() {
109                return;
110            }
111        }
112        let [branch1, branch2] = if kd_difference(mid, axis) < Scalar::zero() {
113            [&kdtree[..mid_idx], &kdtree[mid_idx + 1..]]
114        } else {
115            [&kdtree[mid_idx + 1..], &kdtree[..mid_idx]]
116        };
117        if !branch1.is_empty() {
118            recurse(nearest, branch1, (axis + 1) % dim, dim, kd_difference);
119        }
120        if !branch2.is_empty() {
121            let diff = kd_difference(mid, axis);
122            if diff * diff < nearest.squared_distance {
123                recurse(nearest, branch2, (axis + 1) % dim, dim, kd_difference);
124            }
125        }
126    }
127    assert!(!kdtree.is_empty());
128    let mut nearest = ItemAndDistance {
129        item: &kdtree[0],
130        squared_distance: squared_distance(&kdtree[0], dim, kd_difference),
131    };
132    recurse(&mut nearest, kdtree, 0, dim, kd_difference);
133    nearest
134}