oxirs_vec/
tree_indices_kdtree.rs1use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
7use crate::Vector;
8use anyhow::Result;
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12pub struct KdTree {
14 pub(crate) root: Option<Box<KdNode>>,
15 pub(crate) data: Vec<(String, Vector)>,
16 pub(crate) config: TreeIndexConfig,
17}
18
/// One node of a [`KdTree`].
///
/// A node is either a leaf (non-empty `indices`, no children) or an internal
/// node (empty `indices`, both children present).
pub(crate) struct KdNode {
    /// Coordinate this node splits on; left at `0` for leaves.
    split_dim: usize,
    /// Value of the median point along `split_dim`; left at `0.0` for leaves.
    split_value: f32,
    /// Subtree holding the points sorted before the median along `split_dim`.
    left: Option<Box<KdNode>>,
    /// Subtree holding the median point and everything sorted after it.
    right: Option<Box<KdNode>>,
    /// Positions into the owning tree's `data` for the points stored in this
    /// leaf; empty for internal nodes.
    indices: Vec<usize>,
}
31
32impl KdTree {
33 pub fn new(config: TreeIndexConfig) -> Self {
34 Self {
35 root: None,
36 data: Vec::new(),
37 config,
38 }
39 }
40
41 pub fn build(&mut self) -> Result<()> {
42 if self.data.is_empty() {
43 return Ok(());
44 }
45
46 let indices: Vec<usize> = (0..self.data.len()).collect();
47 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
48
49 self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
50 Ok(())
51 }
52
53 fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
54 let max_depth = if !self.data.is_empty() {
56 ((self.data.len() as f32).log2() * 2.0) as usize + 10
57 } else {
58 50
59 };
60
61 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
62 return Ok(KdNode {
63 split_dim: 0,
64 split_value: 0.0,
65 left: None,
66 right: None,
67 indices,
68 });
69 }
70
71 let dimensions = points[0].len();
72 let split_dim = depth % dimensions;
73
74 let mut values: Vec<(f32, usize)> = indices
76 .iter()
77 .map(|&idx| (points[idx][split_dim], idx))
78 .collect();
79
80 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
81
82 let median_idx = values.len() / 2;
83 let split_value = values[median_idx].0;
84
85 let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
86
87 let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
88
89 if left_indices.is_empty() || right_indices.is_empty() {
91 return Ok(KdNode {
92 split_dim: 0,
93 split_value: 0.0,
94 left: None,
95 right: None,
96 indices,
97 });
98 }
99
100 let left = Some(Box::new(self.build_node(
101 points,
102 left_indices,
103 depth + 1,
104 )?));
105
106 let right = Some(Box::new(self.build_node(
107 points,
108 right_indices,
109 depth + 1,
110 )?));
111
112 Ok(KdNode {
113 split_dim,
114 split_value,
115 left,
116 right,
117 indices: Vec::new(),
118 })
119 }
120
121 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
122 if self.root.is_none() {
123 return Vec::new();
124 }
125
126 let mut heap = BinaryHeap::new();
127 self.search_node(
128 self.root
129 .as_ref()
130 .expect("tree should have root after build"),
131 query,
132 k,
133 &mut heap,
134 );
135
136 let mut results: Vec<(usize, f32)> =
137 heap.into_iter().map(|r| (r.index, r.distance)).collect();
138
139 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
140 results
141 }
142
143 fn search_node(
144 &self,
145 node: &KdNode,
146 query: &[f32],
147 k: usize,
148 heap: &mut BinaryHeap<SearchResult>,
149 ) {
150 if !node.indices.is_empty() {
151 for &idx in &node.indices {
153 let point = &self.data[idx].1.as_f32();
154 let dist = self.config.distance_metric.distance(query, point);
155
156 if heap.len() < k {
157 heap.push(SearchResult {
158 index: idx,
159 distance: dist,
160 });
161 } else if dist < heap.peek().expect("heap should have k elements").distance {
162 heap.pop();
163 heap.push(SearchResult {
164 index: idx,
165 distance: dist,
166 });
167 }
168 }
169 return;
170 }
171
172 let go_left = query[node.split_dim] <= node.split_value;
174
175 let (first, second) = if go_left {
176 (&node.left, &node.right)
177 } else {
178 (&node.right, &node.left)
179 };
180
181 if let Some(child) = first {
183 self.search_node(child, query, k, heap);
184 }
185
186 if heap.len() < k || {
188 let split_dist = (query[node.split_dim] - node.split_value).abs();
189 split_dist < heap.peek().expect("heap should have k elements").distance
190 } {
191 if let Some(child) = second {
192 self.search_node(child, query, k, heap);
193 }
194 }
195 }
196}