Skip to main content

oxihuman_core/
kd_tree_2d.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5//! 2D k-d tree for nearest-neighbor queries.
6
/// A 2D point tagged with a caller-chosen identifier.
///
/// The `id` is stored alongside the coordinates and handed back unchanged in
/// query results, so callers can map a hit to their own data (for example an
/// index into an external array). It is always present; it is not an `Option`.
///
/// Equality is field-wise. Because the coordinates are `f32`, a point with a
/// NaN coordinate never compares equal to anything, itself included.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(dead_code)]
pub struct KdPoint2 {
    /// Coordinate along axis 0.
    pub x: f32,
    /// Coordinate along axis 1.
    pub y: f32,
    /// Caller-supplied identifier returned with query results.
    pub id: usize,
}
15
16impl KdPoint2 {
17    #[allow(dead_code)]
18    pub fn new(x: f32, y: f32, id: usize) -> Self {
19        Self { x, y, id }
20    }
21
22    fn dist_sq(self, other: [f32; 2]) -> f32 {
23        let dx = self.x - other[0];
24        let dy = self.y - other[1];
25        dx * dx + dy * dy
26    }
27}
28
29/// Internal node of a 2D k-d tree.
30#[derive(Debug, Clone)]
31#[allow(dead_code)]
32enum KdNode2 {
33    Leaf(KdPoint2),
34    Split {
35        axis: usize,
36        split_val: f32,
37        left: Box<KdNode2>,
38        right: Box<KdNode2>,
39    },
40}
41
42/// A 2D k-d tree.
43#[derive(Debug, Clone, Default)]
44#[allow(dead_code)]
45pub struct KdTree2D {
46    root: Option<Box<KdNode2>>,
47    count: usize,
48}
49
50fn build_node(points: &mut [KdPoint2], depth: usize) -> Option<Box<KdNode2>> {
51    if points.is_empty() {
52        return None;
53    }
54    if points.len() == 1 {
55        return Some(Box::new(KdNode2::Leaf(points[0])));
56    }
57    let axis = depth % 2;
58    if axis == 0 {
59        points.sort_by(|a, b| a.x.partial_cmp(&b.x).unwrap_or(std::cmp::Ordering::Equal));
60    } else {
61        points.sort_by(|a, b| a.y.partial_cmp(&b.y).unwrap_or(std::cmp::Ordering::Equal));
62    }
63    let mid = points.len() / 2;
64    let split_val = if axis == 0 {
65        points[mid].x
66    } else {
67        points[mid].y
68    };
69    let left = build_node(&mut points[..mid], depth + 1);
70    let right = build_node(&mut points[mid..], depth + 1);
71    let node = match (left, right) {
72        (Some(l), Some(r)) => KdNode2::Split {
73            axis,
74            split_val,
75            left: l,
76            right: r,
77        },
78        (Some(l), None) => *l,
79        (None, Some(r)) => *r,
80        (None, None) => unreachable!(),
81    };
82    Some(Box::new(node))
83}
84
85fn nn_search(node: &KdNode2, query: [f32; 2], best: &mut Option<(f32, KdPoint2)>) {
86    match node {
87        KdNode2::Leaf(p) => {
88            let d = p.dist_sq(query);
89            if best.is_none_or(|(bd, _)| d < bd) {
90                *best = Some((d, *p));
91            }
92        }
93        KdNode2::Split {
94            axis,
95            split_val,
96            left,
97            right,
98        } => {
99            let qval = if *axis == 0 { query[0] } else { query[1] };
100            let (near, far) = if qval < *split_val {
101                (left.as_ref(), right.as_ref())
102            } else {
103                (right.as_ref(), left.as_ref())
104            };
105            nn_search(near, query, best);
106            let plane_dist = (qval - split_val) * (qval - split_val);
107            if best.is_none_or(|(bd, _)| plane_dist < bd) {
108                nn_search(far, query, best);
109            }
110        }
111    }
112}
113
114impl KdTree2D {
115    /// Build a tree from a slice of points.
116    #[allow(dead_code)]
117    pub fn build(points: &[KdPoint2]) -> Self {
118        let mut pts = points.to_vec();
119        let n = pts.len();
120        Self {
121            root: build_node(&mut pts, 0),
122            count: n,
123        }
124    }
125
126    /// Nearest neighbor query. Returns the nearest point and its squared distance.
127    #[allow(dead_code)]
128    pub fn nearest(&self, query: [f32; 2]) -> Option<(KdPoint2, f32)> {
129        let root = self.root.as_ref()?;
130        let mut best = None;
131        nn_search(root, query, &mut best);
132        best.map(|(d, p)| (p, d))
133    }
134
135    /// Number of points.
136    #[allow(dead_code)]
137    pub fn len(&self) -> usize {
138        self.count
139    }
140
141    /// Returns true if there are no points.
142    #[allow(dead_code)]
143    pub fn is_empty(&self) -> bool {
144        self.count == 0
145    }
146}
147
148/// Build a `KdTree2D` from raw (x, y) pairs.
149#[allow(dead_code)]
150pub fn kd2_build(xys: &[[f32; 2]]) -> KdTree2D {
151    let pts: Vec<KdPoint2> = xys
152        .iter()
153        .enumerate()
154        .map(|(i, &[x, y])| KdPoint2::new(x, y, i))
155        .collect();
156    KdTree2D::build(&pts)
157}
158
159/// Nearest-neighbor distance (squared) for a query.
160#[allow(dead_code)]
161pub fn kd2_nn_dist_sq(tree: &KdTree2D, query: [f32; 2]) -> Option<f32> {
162    tree.nearest(query).map(|(_, d)| d)
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw coordinates shared by most tests; ids are the array positions.
    const SAMPLE: [[f32; 2]; 4] = [[0.0, 0.0], [3.0, 0.0], [1.0, 2.0], [5.0, 5.0]];

    fn sample_tree() -> KdTree2D {
        kd2_build(&SAMPLE)
    }

    #[test]
    fn empty_tree_returns_none() {
        let t = KdTree2D::build(&[]);
        assert!(t.nearest([0.0, 0.0]).is_none());
    }

    #[test]
    fn single_point_is_nearest() {
        let t = kd2_build(&[[2.0, 3.0]]);
        let (p, d) = t.nearest([0.0, 0.0]).expect("should succeed");
        assert_eq!(p.id, 0);
        // 2^2 + 3^2, exactly representable in f32.
        assert_eq!(d, 13.0);
    }

    #[test]
    fn nearest_to_origin() {
        let t = sample_tree();
        let (p, d) = t.nearest([0.0, 0.0]).expect("should succeed");
        assert_eq!(p.id, 0);
        assert!(d < 1e-5);
    }

    #[test]
    fn nearest_to_far_point() {
        let t = sample_tree();
        let (p, _) = t.nearest([5.0, 5.0]).expect("should succeed");
        assert_eq!(p.id, 3);
    }

    #[test]
    fn tree_len_matches_input() {
        let t = sample_tree();
        assert_eq!(t.len(), 4);
    }

    #[test]
    fn is_empty_false_for_nonempty() {
        let t = sample_tree();
        assert!(!t.is_empty());
    }

    #[test]
    fn is_empty_true_for_empty() {
        let t = KdTree2D::build(&[]);
        assert!(t.is_empty());
    }

    #[test]
    fn nn_dist_sq_is_zero_for_exact_match() {
        let t = kd2_build(&[[1.0, 1.0], [2.0, 2.0]]);
        let d = kd2_nn_dist_sq(&t, [1.0, 1.0]).expect("should succeed");
        assert!(d < 1e-6);
    }

    #[test]
    fn nearest_among_two_picks_closer() {
        let t = kd2_build(&[[0.0, 0.0], [10.0, 0.0]]);
        let (p, _) = t.nearest([3.0, 0.0]).expect("should succeed");
        assert_eq!(p.id, 0);
    }

    #[test]
    fn build_from_raw_xy_assigns_ids() {
        let xys = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let t = kd2_build(&xys);
        assert_eq!(t.len(), 3);
        // Querying each input location must return the point whose id is
        // that location's index in the input.
        for (i, &xy) in xys.iter().enumerate() {
            let (p, d) = t.nearest(xy).expect("should succeed");
            assert_eq!(p.id, i);
            assert_eq!(d, 0.0);
        }
    }

    #[test]
    fn every_sample_point_finds_itself() {
        let t = sample_tree();
        for (i, &xy) in SAMPLE.iter().enumerate() {
            let (p, d) = t.nearest(xy).expect("should succeed");
            assert_eq!(p.id, i);
            assert_eq!(d, 0.0);
        }
    }

    #[test]
    fn duplicate_points_are_still_found() {
        let t = kd2_build(&[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [2.0, 2.0]]);
        assert_eq!(t.len(), 4);
        // Which duplicate is returned is not specified; only the distance is.
        let d = kd2_nn_dist_sq(&t, [1.0, 1.0]).expect("should succeed");
        assert_eq!(d, 0.0);
    }
}