Skip to main content

oxihuman_core/
kd_tree.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5//! k-d tree for nearest neighbor search in 2D and 3D.
6
/// A point in 3D space tagged with a caller-chosen identifier.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KdPoint3 {
    /// Coordinates in `[x, y, z]` order.
    pub pos: [f32; 3],
    /// Identifier carried alongside the coordinates.
    pub id: usize,
}

impl KdPoint3 {
    /// Creates a point at `(x, y, z)` carrying `id`.
    pub fn new(x: f32, y: f32, z: f32, id: usize) -> Self {
        Self { id, pos: [x, y, z] }
    }

    /// Squared Euclidean distance from this point to `other`.
    fn dist_sq(&self, other: &[f32; 3]) -> f32 {
        self.pos
            .iter()
            .zip(other)
            .map(|(mine, theirs)| (mine - theirs).powi(2))
            .sum()
    }
}
23
24/// A k-d tree node.
25#[derive(Debug)]
26struct KdNode {
27    point: KdPoint3,
28    left: Option<Box<KdNode>>,
29    right: Option<Box<KdNode>>,
30    axis: usize,
31}
32
33/// k-d tree for 3D nearest neighbor queries.
34#[derive(Default)]
35pub struct KdTree3 {
36    root: Option<Box<KdNode>>,
37    count: usize,
38}
39
40fn build(points: &mut [KdPoint3], depth: usize) -> Option<Box<KdNode>> {
41    if points.is_empty() {
42        return None;
43    }
44    let axis = depth % 3;
45    points.sort_by(|a, b| {
46        a.pos[axis]
47            .partial_cmp(&b.pos[axis])
48            .unwrap_or(std::cmp::Ordering::Equal)
49    });
50    let mid = points.len() / 2;
51    Some(Box::new(KdNode {
52        point: points[mid],
53        axis,
54        left: build(&mut points[..mid], depth + 1),
55        right: build(&mut points[mid + 1..], depth + 1),
56    }))
57}
58
59fn nn_search<'a>(node: &'a KdNode, query: &[f32; 3], best: &mut Option<(f32, &'a KdPoint3)>) {
60    let d = node.point.dist_sq(query);
61    if best.is_none_or(|(bd, _)| d < bd) {
62        *best = Some((d, &node.point));
63    }
64    let axis = node.axis;
65    let diff = query[axis] - node.point.pos[axis];
66    let (near, far) = if diff <= 0.0 {
67        (node.left.as_deref(), node.right.as_deref())
68    } else {
69        (node.right.as_deref(), node.left.as_deref())
70    };
71    if let Some(n) = near {
72        nn_search(n, query, best);
73    }
74    if let Some(f) = far {
75        let best_d = best.map(|(bd, _)| bd).unwrap_or(f32::INFINITY);
76        if diff * diff <= best_d {
77            nn_search(f, query, best);
78        }
79    }
80}
81
82fn range_search(node: &KdNode, query: &[f32; 3], r_sq: f32, result: &mut Vec<KdPoint3>) {
83    if node.point.dist_sq(query) <= r_sq {
84        result.push(node.point);
85    }
86    let axis = node.axis;
87    let diff = query[axis] - node.point.pos[axis];
88    if diff - r_sq.sqrt() <= 0.0 {
89        if let Some(n) = &node.left {
90            range_search(n, query, r_sq, result);
91        }
92    }
93    if diff + r_sq.sqrt() >= 0.0 {
94        if let Some(n) = &node.right {
95            range_search(n, query, r_sq, result);
96        }
97    }
98}
99
100impl KdTree3 {
101    /// Build a k-d tree from a list of points.
102    pub fn build(mut points: Vec<KdPoint3>) -> Self {
103        let count = points.len();
104        let root = build(&mut points, 0);
105        KdTree3 { root, count }
106    }
107
108    /// Nearest neighbor query. Returns (point, distance).
109    pub fn nearest(&self, query: &[f32; 3]) -> Option<(KdPoint3, f32)> {
110        let root = self.root.as_deref()?;
111        let mut best = None;
112        nn_search(root, query, &mut best);
113        best.map(|(d, p)| (*p, d.sqrt()))
114    }
115
116    /// Range query: all points within radius `r`.
117    pub fn range_query(&self, query: &[f32; 3], r: f32) -> Vec<KdPoint3> {
118        let mut result = Vec::new();
119        if let Some(root) = &self.root {
120            range_search(root, query, r * r, &mut result);
121        }
122        result
123    }
124
125    /// Number of points in the tree.
126    pub fn len(&self) -> usize {
127        self.count
128    }
129
130    /// True if empty.
131    pub fn is_empty(&self) -> bool {
132        self.count == 0
133    }
134}
135
136/// Build a k-d tree from xyz arrays.
137pub fn new_kd_tree(positions: &[[f32; 3]]) -> KdTree3 {
138    let points: Vec<KdPoint3> = positions
139        .iter()
140        .enumerate()
141        .map(|(i, p)| KdPoint3 { pos: *p, id: i })
142        .collect();
143    KdTree3::build(points)
144}
145
146/// Build a 2D k-d tree (z = 0).
147pub fn new_kd_tree_2d(positions: &[[f32; 2]]) -> KdTree3 {
148    let points: Vec<KdPoint3> = positions
149        .iter()
150        .enumerate()
151        .map(|(i, p)| KdPoint3 {
152            pos: [p[0], p[1], 0.0],
153            id: i,
154        })
155        .collect();
156    KdTree3::build(points)
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` points along the x axis at integer coordinates, id = x.
    fn points_on_x_axis(n: usize) -> Vec<KdPoint3> {
        (0..n)
            .map(|i| KdPoint3::new(i as f32, 0.0, 0.0, i))
            .collect()
    }

    #[test]
    fn test_nearest_basic() {
        let pts = vec![
            KdPoint3::new(0.0, 0.0, 0.0, 0),
            KdPoint3::new(1.0, 0.0, 0.0, 1),
            KdPoint3::new(5.0, 5.0, 5.0, 2),
        ];
        let tree = KdTree3::build(pts);
        let (p, d) = tree.nearest(&[0.1, 0.0, 0.0]).expect("should succeed");
        assert_eq!(p.id, 0);
        assert!(d < 0.2);
    }

    #[test]
    fn test_nearest_single_point() {
        let pts = vec![KdPoint3::new(3.0, 4.0, 0.0, 0)];
        let tree = KdTree3::build(pts);
        let (p, d) = tree.nearest(&[3.0, 4.0, 0.0]).expect("should succeed");
        assert_eq!(p.id, 0);
        assert!(d < 1e-5);
    }

    #[test]
    fn test_empty_tree() {
        let tree = KdTree3::build(vec![]);
        assert!(tree.nearest(&[0.0, 0.0, 0.0]).is_none());
        assert!(tree.is_empty());
    }

    #[test]
    fn test_range_query() {
        let tree = KdTree3::build(points_on_x_axis(10));
        let mut ids: Vec<usize> = tree
            .range_query(&[5.0, 0.0, 0.0], 2.5)
            .iter()
            .map(|p| p.id)
            .collect();
        // Result order follows tree traversal, so compare as a sorted set.
        ids.sort_unstable();
        // Points at x = 3..=7 are within 2.5 of x = 5; x = 2 and x = 8 are not.
        assert_eq!(ids, vec![3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_new_kd_tree() {
        let pos = vec![[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]];
        let tree = new_kd_tree(&pos);
        assert_eq!(tree.len(), 2);
    }

    #[test]
    fn test_new_kd_tree_2d() {
        let pos = vec![[0.0f32, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let tree = new_kd_tree_2d(&pos);
        let (p, _) = tree.nearest(&[0.9, 0.9, 0.0]).expect("should succeed");
        assert_eq!(p.id, 1);
    }

    #[test]
    fn test_len() {
        let tree = KdTree3::build(points_on_x_axis(5));
        assert_eq!(tree.len(), 5);
    }

    #[test]
    fn test_many_points_nearest() {
        let tree = KdTree3::build(points_on_x_axis(100));
        let (p, d) = tree.nearest(&[49.5, 0.0, 0.0]).expect("should succeed");
        // The query is equidistant from both neighbours, so either may win.
        assert!(p.id == 49 || p.id == 50);
        assert!((d - 0.5).abs() < 1e-6);
    }
}