Skip to main content

yscv_recognize/
vp_tree.rs

1use super::similarity::cosine_similarity_prevalidated;
2
3/// Vantage-point tree for approximate nearest-neighbor search.
4#[derive(Debug, Clone)]
5pub struct VpTree {
6    nodes: Vec<VpNode>,
7    embeddings: Vec<Vec<f32>>,
8    ids: Vec<String>,
9}
10
/// A single node of the vantage-point tree, addressed by its position in the
/// node arena.
#[derive(Clone, Debug)]
struct VpNode {
    /// Index of this node's vantage point in the embedding / id vectors.
    index: usize,
    /// Median distance from the vantage point to the other points of its
    /// subtree; `0.0` for a leaf.
    threshold: f32,
    /// Arena position of the "inside" child (points at or below the median).
    left: Option<usize>,
    /// Arena position of the "outside" child (points at or above the median).
    right: Option<usize>,
}
18
/// One hit returned by a k-nearest-neighbor query.
#[derive(Clone, Debug, PartialEq)]
pub struct KnnResult {
    /// Identifier of the matched entry.
    pub id: String,
    /// Distance between the query point and the matched entry.
    pub distance: f32,
}
25
26impl VpTree {
27    /// Create an empty VP-tree.
28    pub fn new() -> Self {
29        Self {
30            nodes: Vec::new(),
31            embeddings: Vec::new(),
32            ids: Vec::new(),
33        }
34    }
35
36    /// Build a VP-tree from a list of (id, embedding) pairs.
37    pub fn build(entries: Vec<(String, Vec<f32>)>) -> Self {
38        if entries.is_empty() {
39            return Self::new();
40        }
41
42        let mut ids: Vec<String> = Vec::with_capacity(entries.len());
43        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(entries.len());
44        for (id, emb) in entries {
45            ids.push(id);
46            embeddings.push(emb);
47        }
48
49        let mut tree = VpTree {
50            nodes: Vec::with_capacity(embeddings.len()),
51            embeddings,
52            ids,
53        };
54
55        let indices: Vec<usize> = (0..tree.embeddings.len()).collect();
56        tree.build_recursive(&indices);
57        tree
58    }
59
60    fn build_recursive(&mut self, indices: &[usize]) -> Option<usize> {
61        if indices.is_empty() {
62            return None;
63        }
64
65        if indices.len() == 1 {
66            let node_index = self.nodes.len();
67            self.nodes.push(VpNode {
68                index: indices[0],
69                threshold: 0.0,
70                left: None,
71                right: None,
72            });
73            return Some(node_index);
74        }
75
76        // Pick the last element as the vantage point (deterministic).
77        let vp_idx = indices[indices.len() - 1];
78        let rest: Vec<usize> = indices[..indices.len() - 1].to_vec();
79
80        // Compute distances from vantage point to all others.
81        let mut dists: Vec<(usize, f32)> = rest
82            .iter()
83            .map(|&i| {
84                let d = cosine_distance(&self.embeddings[vp_idx], &self.embeddings[i]);
85                (i, d)
86            })
87            .collect();
88
89        // Sort by distance to find the median.
90        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
91
92        let median_pos = dists.len() / 2;
93        let threshold = dists[median_pos].1;
94
95        // Split into inside (dist < threshold) and outside (dist >= threshold).
96        let inside: Vec<usize> = dists[..median_pos].iter().map(|&(i, _)| i).collect();
97        let outside: Vec<usize> = dists[median_pos..].iter().map(|&(i, _)| i).collect();
98
99        // Reserve the node index before recursing.
100        let node_index = self.nodes.len();
101        self.nodes.push(VpNode {
102            index: vp_idx,
103            threshold,
104            left: None,
105            right: None,
106        });
107
108        let left = self.build_recursive(&inside);
109        let right = self.build_recursive(&outside);
110
111        self.nodes[node_index].left = left;
112        self.nodes[node_index].right = right;
113
114        Some(node_index)
115    }
116
117    /// Query the tree for the k nearest neighbors to `point`.
118    ///
119    /// Returns a list of `(id, distance)` pairs sorted by ascending distance.
120    pub fn query(&self, point: &[f32], k: usize) -> Vec<KnnResult> {
121        if self.nodes.is_empty() || k == 0 {
122            return Vec::new();
123        }
124
125        let mut heap = BoundedMaxHeap::new(k);
126        self.search(0, point, &mut heap);
127
128        let mut results: Vec<KnnResult> = heap
129            .entries
130            .into_iter()
131            .map(|(dist, idx)| KnnResult {
132                id: self.ids[idx].clone(),
133                distance: dist,
134            })
135            .collect();
136        results.sort_by(|a, b| {
137            a.distance
138                .partial_cmp(&b.distance)
139                .unwrap_or(std::cmp::Ordering::Equal)
140        });
141        results
142    }
143
144    fn search(&self, node_idx: usize, point: &[f32], heap: &mut BoundedMaxHeap) {
145        let node = &self.nodes[node_idx];
146        let dist = cosine_distance(point, &self.embeddings[node.index]);
147
148        heap.push(dist, node.index);
149
150        let tau = heap.max_dist();
151
152        if dist < node.threshold {
153            // Point is inside; search inside first.
154            if let Some(left) = node.left
155                && dist - tau < node.threshold
156            {
157                self.search(left, point, heap);
158            }
159            if let Some(right) = node.right
160                && dist + tau >= node.threshold
161            {
162                self.search(right, point, heap);
163            }
164        } else {
165            // Point is outside; search outside first.
166            if let Some(right) = node.right
167                && dist + tau >= node.threshold
168            {
169                self.search(right, point, heap);
170            }
171            if let Some(left) = node.left
172                && dist - tau < node.threshold
173            {
174                self.search(left, point, heap);
175            }
176        }
177    }
178
179    /// Number of embeddings in the tree.
180    pub fn len(&self) -> usize {
181        self.embeddings.len()
182    }
183
184    /// Whether the tree is empty.
185    pub fn is_empty(&self) -> bool {
186        self.embeddings.is_empty()
187    }
188}
189
190impl Default for VpTree {
191    fn default() -> Self {
192        Self::new()
193    }
194}
195
196/// Distance metric: 1.0 - cosine_similarity.
197fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
198    let sim = cosine_similarity_prevalidated(a, b).unwrap_or(0.0);
199    1.0 - sim
200}
201
/// A max-heap with bounded capacity k, keeping the k smallest distances.
///
/// The largest retained distance sits at the root (`entries[0]`), so a new
/// candidate only has to beat the root to be admitted once the heap is full.
struct BoundedMaxHeap {
    /// Maximum number of entries retained.
    capacity: usize,
    /// Binary max-heap keyed on distance: `(distance, embedding_index)`.
    entries: Vec<(f32, usize)>,
}

impl BoundedMaxHeap {
    /// Upper bound on the number of slots reserved up front. Larger heaps
    /// grow on demand, so an enormous `capacity` cannot overflow or exhaust
    /// memory at construction time.
    const MAX_PREALLOC: usize = 1024;

    /// Create an empty heap that retains at most `capacity` entries.
    fn new(capacity: usize) -> Self {
        Self {
            capacity,
            entries: Vec::with_capacity(capacity.min(Self::MAX_PREALLOC)),
        }
    }

    /// Offer a candidate. It is kept if the heap has room, or if it is
    /// strictly closer than the current worst entry (which it then replaces).
    /// A zero-capacity heap ignores every candidate.
    fn push(&mut self, dist: f32, index: usize) {
        if self.entries.len() < self.capacity {
            self.entries.push((dist, index));
            self.sift_up(self.entries.len() - 1);
        } else if matches!(self.entries.first(), Some(&(worst, _)) if dist < worst) {
            self.entries[0] = (dist, index);
            self.sift_down(0);
        }
    }

    /// Current admission bound: `INFINITY` while the heap still has room,
    /// otherwise the largest retained distance. A zero-capacity heap can
    /// admit nothing and reports `NEG_INFINITY`.
    fn max_dist(&self) -> f32 {
        if self.entries.len() < self.capacity {
            f32::INFINITY
        } else {
            self.entries
                .first()
                .map_or(f32::NEG_INFINITY, |&(worst, _)| worst)
        }
    }

    /// Move the entry at `i` towards the root until its parent is not smaller.
    fn sift_up(&mut self, mut i: usize) {
        while i > 0 {
            let parent = (i - 1) / 2;
            if self.entries[i].0 > self.entries[parent].0 {
                self.entries.swap(i, parent);
                i = parent;
            } else {
                break;
            }
        }
    }

    /// Move the entry at `i` towards the leaves until neither child is larger.
    fn sift_down(&mut self, mut i: usize) {
        let n = self.entries.len();
        loop {
            let left = 2 * i + 1;
            let right = left + 1;
            let mut largest = i;
            if left < n && self.entries[left].0 > self.entries[largest].0 {
                largest = left;
            }
            if right < n && self.entries[right].0 > self.entries[largest].0 {
                largest = right;
            }
            if largest == i {
                break;
            }
            self.entries.swap(i, largest);
            i = largest;
        }
    }
}