Skip to main content

oxirs_vec/
tree_indices_kdtree.rs

1//! KD-Tree implementation for nearest-neighbor search.
2//!
//! A KD-tree is a classic space-partitioning binary tree that splits points
//! at the median along axis-aligned dimensions, cycling through the axes by depth.
5
6use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
7use crate::Vector;
8use anyhow::Result;
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12/// KD-Tree implementation
13pub struct KdTree {
14    pub(crate) root: Option<Box<KdNode>>,
15    pub(crate) data: Vec<(String, Vector)>,
16    pub(crate) config: TreeIndexConfig,
17}
18
/// A single node of a [`KdTree`].
///
/// A node with a non-empty `indices` list is treated as a leaf and has no
/// children. Internal nodes are built with an empty `indices` list and split on
/// axis `split_dim` at coordinate `split_value`.
#[derive(Debug)]
pub(crate) struct KdNode {
    /// Axis this node splits on (unused for leaves).
    split_dim: usize,
    /// Median coordinate along `split_dim` (unused for leaves).
    split_value: f32,
    /// Left child: points whose `split_dim` coordinate is `<= split_value`.
    left: Option<Box<KdNode>>,
    /// Right child: points whose `split_dim` coordinate is `>= split_value`.
    /// The median point itself lands here, so coordinates equal to
    /// `split_value` may appear on either side.
    right: Option<Box<KdNode>>,
    /// Positions into the owning tree's `data` for leaf nodes; empty for internal nodes.
    indices: Vec<usize>,
}
31
32impl KdTree {
33    pub fn new(config: TreeIndexConfig) -> Self {
34        Self {
35            root: None,
36            data: Vec::new(),
37            config,
38        }
39    }
40
41    pub fn build(&mut self) -> Result<()> {
42        if self.data.is_empty() {
43            return Ok(());
44        }
45
46        let indices: Vec<usize> = (0..self.data.len()).collect();
47        let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
48
49        self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
50        Ok(())
51    }
52
53    fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
54        // Reasonable stack overflow prevention with proper depth limit
55        let max_depth = if !self.data.is_empty() {
56            ((self.data.len() as f32).log2() * 2.0) as usize + 10
57        } else {
58            50
59        };
60
61        if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
62            return Ok(KdNode {
63                split_dim: 0,
64                split_value: 0.0,
65                left: None,
66                right: None,
67                indices,
68            });
69        }
70
71        let dimensions = points[0].len();
72        let split_dim = depth % dimensions;
73
74        // Find median along split dimension
75        let mut values: Vec<(f32, usize)> = indices
76            .iter()
77            .map(|&idx| (points[idx][split_dim], idx))
78            .collect();
79
80        values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
81
82        let median_idx = values.len() / 2;
83        let split_value = values[median_idx].0;
84
85        let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
86
87        let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
88
89        // Prevent creating empty partitions - create leaf instead
90        if left_indices.is_empty() || right_indices.is_empty() {
91            return Ok(KdNode {
92                split_dim: 0,
93                split_value: 0.0,
94                left: None,
95                right: None,
96                indices,
97            });
98        }
99
100        let left = Some(Box::new(self.build_node(
101            points,
102            left_indices,
103            depth + 1,
104        )?));
105
106        let right = Some(Box::new(self.build_node(
107            points,
108            right_indices,
109            depth + 1,
110        )?));
111
112        Ok(KdNode {
113            split_dim,
114            split_value,
115            left,
116            right,
117            indices: Vec::new(),
118        })
119    }
120
121    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
122        if self.root.is_none() {
123            return Vec::new();
124        }
125
126        let mut heap = BinaryHeap::new();
127        self.search_node(
128            self.root
129                .as_ref()
130                .expect("tree should have root after build"),
131            query,
132            k,
133            &mut heap,
134        );
135
136        let mut results: Vec<(usize, f32)> =
137            heap.into_iter().map(|r| (r.index, r.distance)).collect();
138
139        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
140        results
141    }
142
143    fn search_node(
144        &self,
145        node: &KdNode,
146        query: &[f32],
147        k: usize,
148        heap: &mut BinaryHeap<SearchResult>,
149    ) {
150        if !node.indices.is_empty() {
151            // Leaf node
152            for &idx in &node.indices {
153                let point = &self.data[idx].1.as_f32();
154                let dist = self.config.distance_metric.distance(query, point);
155
156                if heap.len() < k {
157                    heap.push(SearchResult {
158                        index: idx,
159                        distance: dist,
160                    });
161                } else if dist < heap.peek().expect("heap should have k elements").distance {
162                    heap.pop();
163                    heap.push(SearchResult {
164                        index: idx,
165                        distance: dist,
166                    });
167                }
168            }
169            return;
170        }
171
172        // Determine which side to search first
173        let go_left = query[node.split_dim] <= node.split_value;
174
175        let (first, second) = if go_left {
176            (&node.left, &node.right)
177        } else {
178            (&node.right, &node.left)
179        };
180
181        // Search the nearer side first
182        if let Some(child) = first {
183            self.search_node(child, query, k, heap);
184        }
185
186        // Check if we need to search the other side
187        if heap.len() < k || {
188            let split_dist = (query[node.split_dim] - node.split_value).abs();
189            split_dist < heap.peek().expect("heap should have k elements").distance
190        } {
191            if let Some(child) = second {
192                self.search_node(child, query, k, heap);
193            }
194        }
195    }
196}