1use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
7use crate::Vector;
8use anyhow::Result;
9use oxirs_core::simd::SimdOps;
10use scirs2_core::random::{Random, Rng, RngExt};
11use std::cmp::Ordering;
12use std::collections::BinaryHeap;
13
/// Tree index that recursively splits stored vectors by their projection onto
/// random directions.
///
/// Entries are kept in `data`; the tree itself only stores positions into that
/// vector, so `data` must not be reordered after the tree has been built.
pub struct RandomProjectionTree {
    /// Root of the built tree, or `None` while nothing has been built.
    pub(crate) root: Option<Box<RpNode>>,
    /// Stored entries as `(String, Vector)` pairs; presumably the string is the
    /// entry's identifier — confirm against the code that fills this vector.
    pub(crate) data: Vec<(String, Vector)>,
    /// Settings read by this type: `random_seed`, `max_leaf_size` and
    /// `distance_metric`.
    pub(crate) config: TreeIndexConfig,
}
20
/// One node of a [`RandomProjectionTree`]: either an interior split or a leaf.
pub(crate) struct RpNode {
    /// Direction the points are projected onto at this split (normalised to
    /// unit length when built); empty for leaves.
    projection: Vec<f32>,
    /// Projection value separating the two children (the median projection of
    /// the points seen at build time); `0.0` for leaves.
    threshold: f32,
    /// Child holding the points that sorted below the median position.
    left: Option<Box<RpNode>>,
    /// Child holding the points from the median position upwards.
    right: Option<Box<RpNode>>,
    /// Positions into the tree's `data` vector; non-empty only for leaves.
    indices: Vec<usize>,
}
33
34impl RandomProjectionTree {
35 pub fn new(config: TreeIndexConfig) -> Self {
36 Self {
37 root: None,
38 data: Vec::new(),
39 config,
40 }
41 }
42
43 pub fn build(&mut self) -> Result<()> {
44 if self.data.is_empty() {
45 return Ok(());
46 }
47
48 let indices: Vec<usize> = (0..self.data.len()).collect();
49 let dimensions = self.data[0].1.dimensions;
50
51 let mut rng = if let Some(seed) = self.config.random_seed {
52 Random::seed(seed)
53 } else {
54 Random::seed(42)
55 };
56
57 self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
58 Ok(())
59 }
60
61 fn build_node<R: Rng>(
62 &self,
63 indices: Vec<usize>,
64 dimensions: usize,
65 rng: &mut R,
66 ) -> Result<RpNode> {
67 self.build_node_safe(indices, dimensions, rng, 0)
68 }
69
70 #[allow(deprecated)]
71 fn build_node_safe<R: Rng>(
72 &self,
73 indices: Vec<usize>,
74 dimensions: usize,
75 rng: &mut R,
76 depth: usize,
77 ) -> Result<RpNode> {
78 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
80 return Ok(RpNode {
81 projection: Vec::new(),
82 threshold: 0.0,
83 left: None,
84 right: None,
85 indices,
86 });
87 }
88
89 let projection: Vec<f32> = (0..dimensions)
91 .map(|_| rng.random_range(-1.0..1.0))
92 .collect();
93
94 let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
96 let projection: Vec<f32> = if norm > 0.0 {
97 projection.iter().map(|&x| x / norm).collect()
98 } else {
99 projection
100 };
101
102 let mut projections: Vec<(f32, usize)> = indices
104 .iter()
105 .map(|&idx| {
106 let point = &self.data[idx].1.as_f32();
107 let proj_val = f32::dot(point, &projection);
108 (proj_val, idx)
109 })
110 .collect();
111
112 projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
113
114 let median_idx = projections.len() / 2;
116 let threshold = projections[median_idx].0;
117
118 let left_indices: Vec<usize> = projections[..median_idx]
119 .iter()
120 .map(|(_, idx)| *idx)
121 .collect();
122
123 let right_indices: Vec<usize> = projections[median_idx..]
124 .iter()
125 .map(|(_, idx)| *idx)
126 .collect();
127
128 if left_indices.is_empty() || right_indices.is_empty() {
130 return Ok(RpNode {
131 projection: Vec::new(),
132 threshold: 0.0,
133 left: None,
134 right: None,
135 indices,
136 });
137 }
138
139 let left = Some(Box::new(self.build_node_safe(
140 left_indices,
141 dimensions,
142 rng,
143 depth + 1,
144 )?));
145 let right = Some(Box::new(self.build_node_safe(
146 right_indices,
147 dimensions,
148 rng,
149 depth + 1,
150 )?));
151
152 Ok(RpNode {
153 projection,
154 threshold,
155 left,
156 right,
157 indices: Vec::new(),
158 })
159 }
160
161 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
162 if self.root.is_none() {
163 return Vec::new();
164 }
165
166 let mut heap = BinaryHeap::new();
167 self.search_node(
168 self.root
169 .as_ref()
170 .expect("tree should have root after build"),
171 query,
172 k,
173 &mut heap,
174 );
175
176 let mut results: Vec<(usize, f32)> =
177 heap.into_iter().map(|r| (r.index, r.distance)).collect();
178
179 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
180 results
181 }
182
183 fn search_node(
184 &self,
185 node: &RpNode,
186 query: &[f32],
187 k: usize,
188 heap: &mut BinaryHeap<SearchResult>,
189 ) {
190 if !node.indices.is_empty() {
191 for &idx in &node.indices {
193 let point = &self.data[idx].1.as_f32();
194 let dist = self.config.distance_metric.distance(query, point);
195
196 if heap.len() < k {
197 heap.push(SearchResult {
198 index: idx,
199 distance: dist,
200 });
201 } else if dist < heap.peek().expect("heap should have k elements").distance {
202 heap.pop();
203 heap.push(SearchResult {
204 index: idx,
205 distance: dist,
206 });
207 }
208 }
209 return;
210 }
211
212 let query_projection = f32::dot(query, &node.projection);
214
215 let go_left = query_projection <= node.threshold;
217
218 let (first, second) = if go_left {
219 (&node.left, &node.right)
220 } else {
221 (&node.right, &node.left)
222 };
223
224 if let Some(child) = first {
226 self.search_node(child, query, k, heap);
227 }
228
229 if let Some(child) = second {
230 self.search_node(child, query, k, heap);
231 }
232 }
233}