1use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
7use crate::Vector;
8use anyhow::Result;
9use scirs2_core::random::{Random, Rng, RngExt};
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
13pub struct VpTree {
15 pub(crate) root: Option<Box<VpNode>>,
16 pub(crate) data: Vec<(String, Vector)>,
17 pub(crate) config: TreeIndexConfig,
18}
19
/// One node of a [`VpTree`].
///
/// A node is either a leaf (non-empty `indices`, no children) or an internal
/// node (a vantage point, a split radius, and two children).
pub(crate) struct VpNode {
    /// Position in the tree's `data` of this node's vantage point.
    vantage_point: usize,
    /// Split radius: distance from the vantage point to the median element.
    /// Leaves store `0.0`.
    median_distance: f32,
    /// Subtree holding the closer half of the points.
    inside: Option<Box<VpNode>>,
    /// Subtree holding the farther half of the points (median included).
    outside: Option<Box<VpNode>>,
    /// Positions in `data` held directly by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
32
33impl VpTree {
34 pub fn new(config: TreeIndexConfig) -> Self {
35 Self {
36 root: None,
37 data: Vec::new(),
38 config,
39 }
40 }
41
42 pub fn build(&mut self) -> Result<()> {
43 if self.data.is_empty() {
44 return Ok(());
45 }
46
47 let indices: Vec<usize> = (0..self.data.len()).collect();
48 let mut rng = if let Some(seed) = self.config.random_seed {
49 Random::seed(seed)
50 } else {
51 Random::seed(42)
52 };
53
54 self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
55 Ok(())
56 }
57
58 fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
59 self.build_node_safe(indices, rng, 0)
60 }
61
62 #[allow(deprecated)]
63 fn build_node_safe<R: Rng>(
64 &self,
65 mut indices: Vec<usize>,
66 rng: &mut R,
67 depth: usize,
68 ) -> Result<VpNode> {
69 let max_depth = 30; if indices.len() <= self.config.max_leaf_size
77 || indices.len() <= 2 || depth >= max_depth
79 {
80 return Ok(VpNode {
81 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
82 median_distance: 0.0,
83 inside: None,
84 outside: None,
85 indices,
86 });
87 }
88
89 let vp_idx = if indices.len() > 1 {
91 rng.random_range(0..indices.len())
92 } else {
93 0
94 };
95 let vantage_point = indices[vp_idx];
96 indices.remove(vp_idx);
97
98 let vp_data = &self.data[vantage_point].1.as_f32();
100 let mut distances: Vec<(f32, usize)> = indices
101 .iter()
102 .map(|&idx| {
103 let point = &self.data[idx].1.as_f32();
104 let dist = self.config.distance_metric.distance(vp_data, point);
105 (dist, idx)
106 })
107 .collect();
108
109 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
110
111 let median_idx = distances.len() / 2;
112 let median_distance = distances[median_idx].0;
113
114 let inside_indices: Vec<usize> = distances[..median_idx]
115 .iter()
116 .map(|(_, idx)| *idx)
117 .collect();
118
119 let outside_indices: Vec<usize> = distances[median_idx..]
120 .iter()
121 .map(|(_, idx)| *idx)
122 .collect();
123
124 if inside_indices.is_empty() || outside_indices.is_empty() {
126 return Ok(VpNode {
127 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
128 median_distance: 0.0,
129 inside: None,
130 outside: None,
131 indices,
132 });
133 }
134
135 let inside = Some(Box::new(self.build_node_safe(
136 inside_indices,
137 rng,
138 depth + 1,
139 )?));
140 let outside = Some(Box::new(self.build_node_safe(
141 outside_indices,
142 rng,
143 depth + 1,
144 )?));
145
146 Ok(VpNode {
147 vantage_point,
148 median_distance,
149 inside,
150 outside,
151 indices: Vec::new(),
152 })
153 }
154
155 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
156 if self.root.is_none() {
157 return Vec::new();
158 }
159
160 let mut heap = BinaryHeap::new();
161 self.search_node(
162 self.root
163 .as_ref()
164 .expect("tree should have root after build"),
165 query,
166 k,
167 &mut heap,
168 f32::INFINITY,
169 );
170
171 let mut results: Vec<(usize, f32)> =
172 heap.into_iter().map(|r| (r.index, r.distance)).collect();
173
174 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
175 results
176 }
177
178 fn search_node(
179 &self,
180 node: &VpNode,
181 query: &[f32],
182 k: usize,
183 heap: &mut BinaryHeap<SearchResult>,
184 tau: f32,
185 ) -> f32 {
186 let mut tau = tau;
187
188 if !node.indices.is_empty() {
189 for &idx in &node.indices {
191 let point = &self.data[idx].1.as_f32();
192 let dist = self.config.distance_metric.distance(query, point);
193
194 if dist < tau {
195 if heap.len() < k {
196 heap.push(SearchResult {
197 index: idx,
198 distance: dist,
199 });
200 } else if dist < heap.peek().expect("heap should have k elements").distance {
201 heap.pop();
202 heap.push(SearchResult {
203 index: idx,
204 distance: dist,
205 });
206 }
207
208 if heap.len() >= k {
209 tau = heap.peek().expect("heap should have k elements").distance;
210 }
211 }
212 }
213 return tau;
214 }
215
216 let vp_data = &self.data[node.vantage_point].1.as_f32();
218 let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
219
220 if dist_to_vp < tau {
222 if heap.len() < k {
223 heap.push(SearchResult {
224 index: node.vantage_point,
225 distance: dist_to_vp,
226 });
227 } else if dist_to_vp < heap.peek().expect("heap should have k elements").distance {
228 heap.pop();
229 heap.push(SearchResult {
230 index: node.vantage_point,
231 distance: dist_to_vp,
232 });
233 }
234
235 if heap.len() >= k {
236 tau = heap.peek().expect("heap should have k elements").distance;
237 }
238 }
239
240 if dist_to_vp < node.median_distance {
242 if let Some(inside) = &node.inside {
244 tau = self.search_node(inside, query, k, heap, tau);
245 }
246
247 if dist_to_vp + tau >= node.median_distance {
249 if let Some(outside) = &node.outside {
250 tau = self.search_node(outside, query, k, heap, tau);
251 }
252 }
253 } else {
254 if let Some(outside) = &node.outside {
256 tau = self.search_node(outside, query, k, heap, tau);
257 }
258
259 if dist_to_vp - tau <= node.median_distance {
261 if let Some(inside) = &node.inside {
262 tau = self.search_node(inside, query, k, heap, tau);
263 }
264 }
265 }
266
267 tau
268 }
269}