Skip to main content

oxirs_vec/
tree_indices_rptree.rs

1//! Random Projection Tree implementation for approximate nearest-neighbor search.
2//!
3//! A random projection tree recursively partitions points using randomly
4//! generated projection vectors, providing fast approximate search.
5
6use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
7use crate::Vector;
8use anyhow::Result;
9use oxirs_core::simd::SimdOps;
10use scirs2_core::random::{Random, Rng, RngExt};
11use std::cmp::Ordering;
12use std::collections::BinaryHeap;
13
/// Random Projection Tree index for approximate nearest-neighbor search.
///
/// Holds the indexed `(String, Vector)` entries together with the tree built
/// over them. `root` starts out as `None` and is set by
/// [`RandomProjectionTree::build`] when `data` is non-empty.
pub struct RandomProjectionTree {
    /// Root of the built tree; `None` until `build` has run on non-empty data.
    pub(crate) root: Option<Box<RpNode>>,
    /// Indexed entries; tree leaves refer to them by position in this vector.
    /// NOTE(review): the `String` is presumably the entry's identifier — confirm with callers.
    pub(crate) data: Vec<(String, Vector)>,
    /// Index settings; this file reads `max_leaf_size`, `random_seed` and `distance_metric`.
    pub(crate) config: TreeIndexConfig,
}
20
/// A node of the random projection tree.
///
/// Internal nodes carry a projection vector, a split threshold and two
/// children, and keep `indices` empty. Leaf nodes carry only `indices`, with
/// an empty `projection`, a zero `threshold` and no children.
pub(crate) struct RpNode {
    /// Random projection vector (scaled to unit length when its norm is non-zero; empty for leaves)
    projection: Vec<f32>,
    /// Projection threshold: the projection value of the median point at build time
    threshold: f32,
    /// Left child: points sorted before the median position (projection <= threshold)
    left: Option<Box<RpNode>>,
    /// Right child: the median point and everything sorted after it (projection >= threshold)
    right: Option<Box<RpNode>>,
    /// Indices into the tree's `data` vector; populated only for leaf nodes
    indices: Vec<usize>,
}
33
34impl RandomProjectionTree {
35    pub fn new(config: TreeIndexConfig) -> Self {
36        Self {
37            root: None,
38            data: Vec::new(),
39            config,
40        }
41    }
42
43    pub fn build(&mut self) -> Result<()> {
44        if self.data.is_empty() {
45            return Ok(());
46        }
47
48        let indices: Vec<usize> = (0..self.data.len()).collect();
49        let dimensions = self.data[0].1.dimensions;
50
51        let mut rng = if let Some(seed) = self.config.random_seed {
52            Random::seed(seed)
53        } else {
54            Random::seed(42)
55        };
56
57        self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
58        Ok(())
59    }
60
61    fn build_node<R: Rng>(
62        &self,
63        indices: Vec<usize>,
64        dimensions: usize,
65        rng: &mut R,
66    ) -> Result<RpNode> {
67        self.build_node_safe(indices, dimensions, rng, 0)
68    }
69
70    #[allow(deprecated)]
71    fn build_node_safe<R: Rng>(
72        &self,
73        indices: Vec<usize>,
74        dimensions: usize,
75        rng: &mut R,
76        depth: usize,
77    ) -> Result<RpNode> {
78        // Very strict stack overflow prevention - similar to BallTree approach
79        if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
80            return Ok(RpNode {
81                projection: Vec::new(),
82                threshold: 0.0,
83                left: None,
84                right: None,
85                indices,
86            });
87        }
88
89        // Generate random projection vector
90        let projection: Vec<f32> = (0..dimensions)
91            .map(|_| rng.random_range(-1.0..1.0))
92            .collect();
93
94        // Normalize projection vector
95        let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
96        let projection: Vec<f32> = if norm > 0.0 {
97            projection.iter().map(|&x| x / norm).collect()
98        } else {
99            projection
100        };
101
102        // Project all points
103        let mut projections: Vec<(f32, usize)> = indices
104            .iter()
105            .map(|&idx| {
106                let point = &self.data[idx].1.as_f32();
107                let proj_val = f32::dot(point, &projection);
108                (proj_val, idx)
109            })
110            .collect();
111
112        projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
113
114        // Choose median as threshold
115        let median_idx = projections.len() / 2;
116        let threshold = projections[median_idx].0;
117
118        let left_indices: Vec<usize> = projections[..median_idx]
119            .iter()
120            .map(|(_, idx)| *idx)
121            .collect();
122
123        let right_indices: Vec<usize> = projections[median_idx..]
124            .iter()
125            .map(|(_, idx)| *idx)
126            .collect();
127
128        // Prevent creating empty partitions - create leaf instead
129        if left_indices.is_empty() || right_indices.is_empty() {
130            return Ok(RpNode {
131                projection: Vec::new(),
132                threshold: 0.0,
133                left: None,
134                right: None,
135                indices,
136            });
137        }
138
139        let left = Some(Box::new(self.build_node_safe(
140            left_indices,
141            dimensions,
142            rng,
143            depth + 1,
144        )?));
145        let right = Some(Box::new(self.build_node_safe(
146            right_indices,
147            dimensions,
148            rng,
149            depth + 1,
150        )?));
151
152        Ok(RpNode {
153            projection,
154            threshold,
155            left,
156            right,
157            indices: Vec::new(),
158        })
159    }
160
161    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
162        if self.root.is_none() {
163            return Vec::new();
164        }
165
166        let mut heap = BinaryHeap::new();
167        self.search_node(
168            self.root
169                .as_ref()
170                .expect("tree should have root after build"),
171            query,
172            k,
173            &mut heap,
174        );
175
176        let mut results: Vec<(usize, f32)> =
177            heap.into_iter().map(|r| (r.index, r.distance)).collect();
178
179        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
180        results
181    }
182
183    fn search_node(
184        &self,
185        node: &RpNode,
186        query: &[f32],
187        k: usize,
188        heap: &mut BinaryHeap<SearchResult>,
189    ) {
190        if !node.indices.is_empty() {
191            // Leaf node
192            for &idx in &node.indices {
193                let point = &self.data[idx].1.as_f32();
194                let dist = self.config.distance_metric.distance(query, point);
195
196                if heap.len() < k {
197                    heap.push(SearchResult {
198                        index: idx,
199                        distance: dist,
200                    });
201                } else if dist < heap.peek().expect("heap should have k elements").distance {
202                    heap.pop();
203                    heap.push(SearchResult {
204                        index: idx,
205                        distance: dist,
206                    });
207                }
208            }
209            return;
210        }
211
212        // Project query
213        let query_projection = f32::dot(query, &node.projection);
214
215        // Determine which side to search first
216        let go_left = query_projection <= node.threshold;
217
218        let (first, second) = if go_left {
219            (&node.left, &node.right)
220        } else {
221            (&node.right, &node.left)
222        };
223
224        // Search both sides (random projections don't provide distance bounds)
225        if let Some(child) = first {
226            self.search_node(child, query, k, heap);
227        }
228
229        if let Some(child) = second {
230            self.search_node(child, query, k, heap);
231        }
232    }
233}