yscv_recognize/
vp_tree.rs1use super::similarity::cosine_similarity_prevalidated;
2
3#[derive(Debug, Clone)]
5pub struct VpTree {
6 nodes: Vec<VpNode>,
7 embeddings: Vec<Vec<f32>>,
8 ids: Vec<String>,
9}
10
/// One node of a [`VpTree`], stored in the tree's flat `nodes` array.
#[derive(Clone, Debug)]
struct VpNode {
    /// Index of this node's vantage point into the tree's `embeddings` / `ids`.
    index: usize,
    /// Partition radius: the median distance from the vantage point to the
    /// points below it. `left` holds points no farther than this and `right`
    /// holds points no closer. Leaves store 0.0.
    threshold: f32,
    /// Node index of the "inside" subtree, if any.
    left: Option<usize>,
    /// Node index of the "outside" subtree, if any.
    right: Option<usize>,
}
18
/// A single neighbour returned by a [`VpTree`] query.
#[derive(Clone, Debug, PartialEq)]
pub struct KnnResult {
    /// Identifier the neighbour was inserted with.
    pub id: String,
    /// Distance between the query point and this neighbour.
    pub distance: f32,
}
25
26impl VpTree {
27 pub fn new() -> Self {
29 Self {
30 nodes: Vec::new(),
31 embeddings: Vec::new(),
32 ids: Vec::new(),
33 }
34 }
35
36 pub fn build(entries: Vec<(String, Vec<f32>)>) -> Self {
38 if entries.is_empty() {
39 return Self::new();
40 }
41
42 let mut ids: Vec<String> = Vec::with_capacity(entries.len());
43 let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(entries.len());
44 for (id, emb) in entries {
45 ids.push(id);
46 embeddings.push(emb);
47 }
48
49 let mut tree = VpTree {
50 nodes: Vec::with_capacity(embeddings.len()),
51 embeddings,
52 ids,
53 };
54
55 let indices: Vec<usize> = (0..tree.embeddings.len()).collect();
56 tree.build_recursive(&indices);
57 tree
58 }
59
60 fn build_recursive(&mut self, indices: &[usize]) -> Option<usize> {
61 if indices.is_empty() {
62 return None;
63 }
64
65 if indices.len() == 1 {
66 let node_index = self.nodes.len();
67 self.nodes.push(VpNode {
68 index: indices[0],
69 threshold: 0.0,
70 left: None,
71 right: None,
72 });
73 return Some(node_index);
74 }
75
76 let vp_idx = indices[indices.len() - 1];
78 let rest: Vec<usize> = indices[..indices.len() - 1].to_vec();
79
80 let mut dists: Vec<(usize, f32)> = rest
82 .iter()
83 .map(|&i| {
84 let d = cosine_distance(&self.embeddings[vp_idx], &self.embeddings[i]);
85 (i, d)
86 })
87 .collect();
88
89 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
91
92 let median_pos = dists.len() / 2;
93 let threshold = dists[median_pos].1;
94
95 let inside: Vec<usize> = dists[..median_pos].iter().map(|&(i, _)| i).collect();
97 let outside: Vec<usize> = dists[median_pos..].iter().map(|&(i, _)| i).collect();
98
99 let node_index = self.nodes.len();
101 self.nodes.push(VpNode {
102 index: vp_idx,
103 threshold,
104 left: None,
105 right: None,
106 });
107
108 let left = self.build_recursive(&inside);
109 let right = self.build_recursive(&outside);
110
111 self.nodes[node_index].left = left;
112 self.nodes[node_index].right = right;
113
114 Some(node_index)
115 }
116
117 pub fn query(&self, point: &[f32], k: usize) -> Vec<KnnResult> {
121 if self.nodes.is_empty() || k == 0 {
122 return Vec::new();
123 }
124
125 let mut heap = BoundedMaxHeap::new(k);
126 self.search(0, point, &mut heap);
127
128 let mut results: Vec<KnnResult> = heap
129 .entries
130 .into_iter()
131 .map(|(dist, idx)| KnnResult {
132 id: self.ids[idx].clone(),
133 distance: dist,
134 })
135 .collect();
136 results.sort_by(|a, b| {
137 a.distance
138 .partial_cmp(&b.distance)
139 .unwrap_or(std::cmp::Ordering::Equal)
140 });
141 results
142 }
143
144 fn search(&self, node_idx: usize, point: &[f32], heap: &mut BoundedMaxHeap) {
145 let node = &self.nodes[node_idx];
146 let dist = cosine_distance(point, &self.embeddings[node.index]);
147
148 heap.push(dist, node.index);
149
150 let tau = heap.max_dist();
151
152 if dist < node.threshold {
153 if let Some(left) = node.left
155 && dist - tau < node.threshold
156 {
157 self.search(left, point, heap);
158 }
159 if let Some(right) = node.right
160 && dist + tau >= node.threshold
161 {
162 self.search(right, point, heap);
163 }
164 } else {
165 if let Some(right) = node.right
167 && dist + tau >= node.threshold
168 {
169 self.search(right, point, heap);
170 }
171 if let Some(left) = node.left
172 && dist - tau < node.threshold
173 {
174 self.search(left, point, heap);
175 }
176 }
177 }
178
179 pub fn len(&self) -> usize {
181 self.embeddings.len()
182 }
183
184 pub fn is_empty(&self) -> bool {
186 self.embeddings.is_empty()
187 }
188}
189
190impl Default for VpTree {
191 fn default() -> Self {
192 Self::new()
193 }
194}
195
196fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
198 let sim = cosine_similarity_prevalidated(a, b).unwrap_or(0.0);
199 1.0 - sim
200}
201
/// Fixed-capacity max-heap of `(distance, index)` candidates that retains the
/// `capacity` smallest distances pushed into it.
///
/// `entries` is an implicit binary heap ordered on distance, so `entries[0]`
/// is the worst retained candidate and can be evicted in O(log capacity).
struct BoundedMaxHeap {
    /// Maximum number of candidates retained.
    capacity: usize,
    /// Heap-ordered `(distance, index)` pairs; never longer than `capacity`.
    entries: Vec<(f32, usize)>,
}

impl BoundedMaxHeap {
    /// Creates an empty heap retaining at most `capacity` candidates.
    fn new(capacity: usize) -> Self {
        let mut entries: Vec<(f32, usize)> = Vec::new();
        // Best-effort preallocation. The heap never exceeds `capacity`, so no
        // spare slot is needed. `try_reserve_exact` reports impossible sizes
        // such as `usize::MAX` instead of panicking; the previous
        // `capacity + 1` overflowed for that value.
        let _ = entries.try_reserve_exact(capacity);
        Self { capacity, entries }
    }

    /// Offers a candidate.
    ///
    /// The candidate is kept while the heap has room, or when it is strictly
    /// closer than the current worst entry, which it then replaces. A
    /// zero-capacity heap ignores every candidate; previously this case
    /// indexed an empty vector and panicked.
    fn push(&mut self, dist: f32, index: usize) {
        if self.entries.len() < self.capacity {
            self.entries.push((dist, index));
            let last = self.entries.len() - 1;
            self.sift_up(last);
        } else if self.entries.first().is_some_and(|&(worst, _)| dist < worst) {
            self.entries[0] = (dist, index);
            self.sift_down(0);
        }
    }

    /// Moves the entry at `i` towards the root until its parent is not smaller.
    fn sift_up(&mut self, mut i: usize) {
        while i > 0 {
            let parent = (i - 1) / 2;
            if self.entries[i].0 > self.entries[parent].0 {
                self.entries.swap(i, parent);
                i = parent;
            } else {
                break;
            }
        }
    }

    /// Moves the entry at `i` towards the leaves until no child is larger.
    fn sift_down(&mut self, mut i: usize) {
        let n = self.entries.len();
        loop {
            let left = 2 * i + 1;
            let right = left + 1;
            let mut largest = i;
            if left < n && self.entries[left].0 > self.entries[largest].0 {
                largest = left;
            }
            if right < n && self.entries[right].0 > self.entries[largest].0 {
                largest = right;
            }
            if largest == i {
                break;
            }
            self.entries.swap(i, largest);
            i = largest;
        }
    }

    /// Current pruning radius.
    ///
    /// Returns `f32::INFINITY` while the heap still has room and the worst
    /// retained distance once it is full. A zero-capacity heap returns
    /// `f32::NEG_INFINITY`, since it can never admit anything.
    fn max_dist(&self) -> f32 {
        if self.entries.len() < self.capacity {
            f32::INFINITY
        } else {
            self.entries
                .first()
                .map_or(f32::NEG_INFINITY, |&(worst, _)| worst)
        }
    }
}