Skip to main content

oxirs_vec/
tree_indices_vptree.rs

1//! VP-Tree (Vantage Point Tree) implementation for metric-space search.
2//!
3//! A VP-tree partitions points by their distance from a randomly chosen
4//! vantage point, supporting efficient search in arbitrary metric spaces.
5
6use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
7use crate::Vector;
8use anyhow::Result;
9use scirs2_core::random::{Random, Rng, RngExt};
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
13/// VP-Tree (Vantage Point Tree) implementation
/// VP-Tree (Vantage Point Tree) implementation
///
/// Points are stored in `data`; every index held by the tree nodes refers to a
/// position in that vector. The tree itself (`root`) is only populated by
/// `build`, so indices stay valid only while `data` is not modified after a build.
pub struct VpTree {
    /// Root of the built tree; `None` until `build` has run over non-empty data.
    pub(crate) root: Option<Box<VpNode>>,
    /// Indexed points as `(String, Vector)` pairs; node indices point into this vector.
    /// NOTE(review): the `String` is presumably an external id — confirm with callers.
    pub(crate) data: Vec<(String, Vector)>,
    /// Build/search settings read by this file: `max_leaf_size`, `random_seed`
    /// and `distance_metric`.
    pub(crate) config: TreeIndexConfig,
}
19
/// A node of a [`VpTree`].
///
/// A node is a leaf when `indices` is non-empty (search scans those points
/// directly); otherwise it is an internal node that splits its points around
/// `median_distance` measured from `vantage_point`.
pub(crate) struct VpNode {
    /// Vantage point index (position in `VpTree::data`).
    /// For leaves this is just the first stored index and is not used by search.
    vantage_point: usize,
    /// Median distance from the vantage point to the node's other points
    /// (0.0 for leaves).
    median_distance: f32,
    /// Subtree of points whose distance to the vantage point is <= the median
    /// (ties with the median may land on either side).
    inside: Option<Box<VpNode>>,
    /// Subtree of points whose distance to the vantage point is >= the median.
    outside: Option<Box<VpNode>>,
    /// Indices for leaf nodes; empty for internal nodes.
    indices: Vec<usize>,
}
32
33impl VpTree {
34    pub fn new(config: TreeIndexConfig) -> Self {
35        Self {
36            root: None,
37            data: Vec::new(),
38            config,
39        }
40    }
41
42    pub fn build(&mut self) -> Result<()> {
43        if self.data.is_empty() {
44            return Ok(());
45        }
46
47        let indices: Vec<usize> = (0..self.data.len()).collect();
48        let mut rng = if let Some(seed) = self.config.random_seed {
49            Random::seed(seed)
50        } else {
51            Random::seed(42)
52        };
53
54        self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
55        Ok(())
56    }
57
58    fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
59        self.build_node_safe(indices, rng, 0)
60    }
61
62    #[allow(deprecated)]
63    fn build_node_safe<R: Rng>(
64        &self,
65        mut indices: Vec<usize>,
66        rng: &mut R,
67        depth: usize,
68    ) -> Result<VpNode> {
69        // Note: Using manual random selection instead of SliceRandom
70
71        // CRITICAL: Extremely strict depth and size limits to prevent stack overflow
72        // For very small datasets or deep recursion, immediately create leaf nodes
73        let max_depth = 30; // Conservative depth limit
74
75        // Aggressive leaf node creation for small datasets
76        if indices.len() <= self.config.max_leaf_size
77            || indices.len() <= 2  // Changed from <= 1 to <= 2 for extra safety
78            || depth >= max_depth
79        {
80            return Ok(VpNode {
81                vantage_point: if indices.is_empty() { 0 } else { indices[0] },
82                median_distance: 0.0,
83                inside: None,
84                outside: None,
85                indices,
86            });
87        }
88
89        // Choose random vantage point - simplified to avoid potential issues
90        let vp_idx = if indices.len() > 1 {
91            rng.random_range(0..indices.len())
92        } else {
93            0
94        };
95        let vantage_point = indices[vp_idx];
96        indices.remove(vp_idx);
97
98        // Calculate distances from vantage point
99        let vp_data = &self.data[vantage_point].1.as_f32();
100        let mut distances: Vec<(f32, usize)> = indices
101            .iter()
102            .map(|&idx| {
103                let point = &self.data[idx].1.as_f32();
104                let dist = self.config.distance_metric.distance(vp_data, point);
105                (dist, idx)
106            })
107            .collect();
108
109        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
110
111        let median_idx = distances.len() / 2;
112        let median_distance = distances[median_idx].0;
113
114        let inside_indices: Vec<usize> = distances[..median_idx]
115            .iter()
116            .map(|(_, idx)| *idx)
117            .collect();
118
119        let outside_indices: Vec<usize> = distances[median_idx..]
120            .iter()
121            .map(|(_, idx)| *idx)
122            .collect();
123
124        // Prevent creating empty partitions - create leaf instead
125        if inside_indices.is_empty() || outside_indices.is_empty() {
126            return Ok(VpNode {
127                vantage_point: if indices.is_empty() { 0 } else { indices[0] },
128                median_distance: 0.0,
129                inside: None,
130                outside: None,
131                indices,
132            });
133        }
134
135        let inside = Some(Box::new(self.build_node_safe(
136            inside_indices,
137            rng,
138            depth + 1,
139        )?));
140        let outside = Some(Box::new(self.build_node_safe(
141            outside_indices,
142            rng,
143            depth + 1,
144        )?));
145
146        Ok(VpNode {
147            vantage_point,
148            median_distance,
149            inside,
150            outside,
151            indices: Vec::new(),
152        })
153    }
154
155    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
156        if self.root.is_none() {
157            return Vec::new();
158        }
159
160        let mut heap = BinaryHeap::new();
161        self.search_node(
162            self.root
163                .as_ref()
164                .expect("tree should have root after build"),
165            query,
166            k,
167            &mut heap,
168            f32::INFINITY,
169        );
170
171        let mut results: Vec<(usize, f32)> =
172            heap.into_iter().map(|r| (r.index, r.distance)).collect();
173
174        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
175        results
176    }
177
178    fn search_node(
179        &self,
180        node: &VpNode,
181        query: &[f32],
182        k: usize,
183        heap: &mut BinaryHeap<SearchResult>,
184        tau: f32,
185    ) -> f32 {
186        let mut tau = tau;
187
188        if !node.indices.is_empty() {
189            // Leaf node
190            for &idx in &node.indices {
191                let point = &self.data[idx].1.as_f32();
192                let dist = self.config.distance_metric.distance(query, point);
193
194                if dist < tau {
195                    if heap.len() < k {
196                        heap.push(SearchResult {
197                            index: idx,
198                            distance: dist,
199                        });
200                    } else if dist < heap.peek().expect("heap should have k elements").distance {
201                        heap.pop();
202                        heap.push(SearchResult {
203                            index: idx,
204                            distance: dist,
205                        });
206                    }
207
208                    if heap.len() >= k {
209                        tau = heap.peek().expect("heap should have k elements").distance;
210                    }
211                }
212            }
213            return tau;
214        }
215
216        // Calculate distance to vantage point
217        let vp_data = &self.data[node.vantage_point].1.as_f32();
218        let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
219
220        // Consider vantage point itself
221        if dist_to_vp < tau {
222            if heap.len() < k {
223                heap.push(SearchResult {
224                    index: node.vantage_point,
225                    distance: dist_to_vp,
226                });
227            } else if dist_to_vp < heap.peek().expect("heap should have k elements").distance {
228                heap.pop();
229                heap.push(SearchResult {
230                    index: node.vantage_point,
231                    distance: dist_to_vp,
232                });
233            }
234
235            if heap.len() >= k {
236                tau = heap.peek().expect("heap should have k elements").distance;
237            }
238        }
239
240        // Search children
241        if dist_to_vp < node.median_distance {
242            // Search inside first
243            if let Some(inside) = &node.inside {
244                tau = self.search_node(inside, query, k, heap, tau);
245            }
246
247            // Check if we need to search outside
248            if dist_to_vp + tau >= node.median_distance {
249                if let Some(outside) = &node.outside {
250                    tau = self.search_node(outside, query, k, heap, tau);
251                }
252            }
253        } else {
254            // Search outside first
255            if let Some(outside) = &node.outside {
256                tau = self.search_node(outside, query, k, heap, tau);
257            }
258
259            // Check if we need to search inside
260            if dist_to_vp - tau <= node.median_distance {
261                if let Some(inside) = &node.inside {
262                    tau = self.search_node(inside, query, k, heap, tau);
263                }
264            }
265        }
266
267        tau
268    }
269}