use std::cmp::Ordering;
use std::collections::BTreeSet;

use ndarray::Array1;
3
4pub struct ANNIndex {
6 vectors: Vec<Array1<f32>>,
8 data: Vec<f32>,
10 lsh: Option<crate::lsh::LSH>,
12 use_random_projections: bool,
14 projection_matrix: Option<Array2<f32>>,
16 projected_dim: usize,
18 vector_dim: usize,
20}
21
22#[derive(Debug, Clone)]
24pub struct ANNResult {
25 pub vector: Array1<f32>,
26 pub data: f32,
27 pub similarity: f32,
28}
29
30impl PartialEq for ANNResult {
31 fn eq(&self, other: &Self) -> bool {
32 self.similarity == other.similarity
33 }
34}
35
36impl Eq for ANNResult {}
37
38impl PartialOrd for ANNResult {
39 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
40 Some(self.cmp(other))
41 }
42}
43
44impl Ord for ANNResult {
45 fn cmp(&self, other: &Self) -> Ordering {
46 other
48 .similarity
49 .partial_cmp(&self.similarity)
50 .unwrap_or(Ordering::Equal)
51 }
52}
53
54impl ANNIndex {
55 pub fn new(vector_dim: usize) -> Self {
57 Self {
58 vectors: Vec::new(),
59 data: Vec::new(),
60 lsh: None,
61 use_random_projections: false,
62 projection_matrix: None,
63 projected_dim: vector_dim / 4, vector_dim,
65 }
66 }
67
68 pub fn with_lsh(mut self, num_tables: usize, hash_size: usize) -> Self {
70 self.lsh = Some(crate::lsh::LSH::new(self.vector_dim, num_tables, hash_size));
71 self
72 }
73
74 pub fn with_random_projections(mut self, projected_dim: usize) -> Self {
76 self.use_random_projections = true;
77 self.projected_dim = projected_dim;
78 self
79 }
80
81 pub fn add_vector(&mut self, vector: Array1<f32>, data: f32) {
83 if self.use_random_projections && self.projection_matrix.is_none() {
85 self.init_random_projections(vector.len());
86 }
87
88 self.vectors.push(vector.clone());
89 self.data.push(data);
90
91 if let Some(ref mut lsh) = self.lsh {
93 lsh.add_vector(vector, data);
94 }
95 }
96
97 pub fn search(
99 &self,
100 query: &Array1<f32>,
101 k: usize,
102 strategy: SearchStrategy,
103 ) -> Vec<ANNResult> {
104 match strategy {
105 SearchStrategy::LSH => self.search_lsh(query, k),
106 SearchStrategy::RandomProjection => self.search_random_projection(query, k),
107 SearchStrategy::Hybrid => self.search_hybrid(query, k),
108 SearchStrategy::Exact => self.search_exact(query, k),
109 }
110 }
111
112 fn search_lsh(&self, query: &Array1<f32>, k: usize) -> Vec<ANNResult> {
114 if let Some(ref lsh) = self.lsh {
115 lsh.query(query, k)
116 .into_iter()
117 .map(|(vec, data, sim)| ANNResult {
118 vector: vec,
119 data,
120 similarity: sim,
121 })
122 .collect()
123 } else {
124 self.search_exact(query, k)
125 }
126 }
127
128 fn search_random_projection(&self, query: &Array1<f32>, k: usize) -> Vec<ANNResult> {
130 if !self.use_random_projections || self.projection_matrix.is_none() {
131 return self.search_exact(query, k);
132 }
133
134 let proj_matrix = self.projection_matrix.as_ref().unwrap();
135 let proj_query = self.project_vector(query, proj_matrix);
136
137 let mut results: Vec<_> = self
138 .vectors
139 .iter()
140 .zip(self.data.iter())
141 .map(|(vec, &data)| {
142 let proj_vec = self.project_vector(vec, proj_matrix);
143 let similarity = cosine_similarity(&proj_query, &proj_vec);
144 ANNResult {
145 vector: vec.clone(),
146 data,
147 similarity,
148 }
149 })
150 .collect();
151
152 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
153 results.truncate(k);
154 results
155 }
156
157 fn search_hybrid(&self, query: &Array1<f32>, k: usize) -> Vec<ANNResult> {
159 let mut candidate_indices = std::collections::HashSet::new();
160 let mut results = Vec::new();
161
162 if let Some(ref lsh) = self.lsh {
164 let lsh_results = lsh.query(query, k * 2);
165 for (vec, _data, _) in lsh_results {
166 for (idx, stored_vec) in self.vectors.iter().enumerate() {
168 if vectors_approximately_equal(&vec, stored_vec) {
169 candidate_indices.insert(idx);
170 break;
171 }
172 }
173 }
174 }
175
176 if self.use_random_projections {
178 let rp_results = self.search_random_projection(query, k * 2);
179 for result in rp_results {
180 for (idx, stored_vec) in self.vectors.iter().enumerate() {
182 if vectors_approximately_equal(&result.vector, stored_vec) {
183 candidate_indices.insert(idx);
184 break;
185 }
186 }
187 }
188 }
189
190 if candidate_indices.len() < k * 3 {
192 for idx in 0..(k * 3).min(self.vectors.len()) {
193 candidate_indices.insert(idx);
194 }
195 }
196
197 for &idx in &candidate_indices {
199 let vec = &self.vectors[idx];
200 let data = self.data[idx];
201 let similarity = cosine_similarity(query, vec);
202 results.push(ANNResult {
203 vector: vec.clone(),
204 data,
205 similarity,
206 });
207 }
208
209 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
210 results.truncate(k);
211 results
212 }
213
214 fn search_exact(&self, query: &Array1<f32>, k: usize) -> Vec<ANNResult> {
216 let mut results: Vec<_> = self
217 .vectors
218 .iter()
219 .zip(self.data.iter())
220 .map(|(vec, &data)| {
221 let similarity = cosine_similarity(query, vec);
222 ANNResult {
223 vector: vec.clone(),
224 data,
225 similarity,
226 }
227 })
228 .collect();
229
230 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
231 results.truncate(k);
232 results
233 }
234
235 fn init_random_projections(&mut self, input_dim: usize) {
237 use rand::Rng;
238 let mut rng = rand::thread_rng();
239
240 assert_eq!(
242 input_dim, self.vector_dim,
243 "Input dimension should match vector dimension"
244 );
245
246 let mut matrix_data = Vec::with_capacity(self.projected_dim * input_dim);
247 for _ in 0..(self.projected_dim * input_dim) {
248 matrix_data.push(rng.gen_range(-1.0..1.0));
249 }
250
251 self.projection_matrix = Some(
252 Array2::from_shape_vec((self.projected_dim, input_dim), matrix_data)
253 .expect("Failed to create projection matrix"),
254 );
255 }
256
257 fn project_vector(&self, vector: &Array1<f32>, proj_matrix: &Array2<f32>) -> Array1<f32> {
259 let mut result = Array1::zeros(self.projected_dim);
260 for i in 0..self.projected_dim {
261 let dot_product: f32 = vector
262 .iter()
263 .zip(proj_matrix.row(i).iter())
264 .map(|(v, p)| v * p)
265 .sum();
266 result[i] = dot_product;
267 }
268 result
269 }
270
271 pub fn stats(&self) -> ANNStats {
273 ANNStats {
274 num_vectors: self.vectors.len(),
275 vector_dim: if self.vectors.is_empty() {
276 0
277 } else {
278 self.vectors[0].len()
279 },
280 has_lsh: self.lsh.is_some(),
281 has_random_projections: self.use_random_projections,
282 projected_dim: if self.use_random_projections {
283 Some(self.projected_dim)
284 } else {
285 None
286 },
287 }
288 }
289}
290
/// Selects the algorithm used by [`ANNIndex::search`].
// `LSH` is public API; renaming it to satisfy the acronym lint would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SearchStrategy {
    /// Query the LSH; falls back to exact search when LSH is not enabled.
    LSH,
    /// Rank by cosine similarity in the randomly projected space; falls back to
    /// exact search when random projections are not set up.
    RandomProjection,
    /// Combine LSH and random-projection candidates, then re-rank them by exact
    /// cosine similarity.
    Hybrid,
    /// Brute-force cosine similarity against every stored vector.
    Exact,
}
303
/// Summary of an [`ANNIndex`]'s contents and configuration, produced by
/// [`ANNIndex::stats`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ANNStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Length of the first stored vector, or 0 when the index is empty.
    pub vector_dim: usize,
    /// Whether LSH is enabled.
    pub has_lsh: bool,
    /// Whether random projections are enabled.
    pub has_random_projections: bool,
    /// Projection target dimension; `Some` only when random projections are enabled.
    pub projected_dim: Option<usize>,
}
313
314fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
316 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
317 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
318 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
319
320 if norm_a == 0.0 || norm_b == 0.0 {
321 0.0
322 } else {
323 dot_product / (norm_a * norm_b)
324 }
325}
326
327fn vectors_approximately_equal(a: &Array1<f32>, b: &Array1<f32>) -> bool {
329 if a.len() != b.len() {
330 return false;
331 }
332
333 let threshold = 1e-6;
334 for (x, y) in a.iter().zip(b.iter()) {
335 if (x - y).abs() > threshold {
336 return false;
337 }
338 }
339 true
340}
341
342use ndarray::Array2;
343
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Shorthand for a 4-dimensional test vector.
    fn v4(a: f32, b: f32, c: f32, d: f32) -> Array1<f32> {
        Array1::from(vec![a, b, c, d])
    }

    #[test]
    fn test_ann_index_creation() {
        let index = ANNIndex::new(128);
        assert!(index.vectors.is_empty());
        assert!(!index.use_random_projections);
        assert!(index.lsh.is_none());
    }

    #[test]
    fn test_ann_with_lsh() {
        assert!(ANNIndex::new(128).with_lsh(4, 8).lsh.is_some());
    }

    #[test]
    fn test_ann_with_random_projections() {
        let index = ANNIndex::new(128).with_random_projections(32);
        assert!(index.use_random_projections);
        assert_eq!(index.projected_dim, 32);
    }

    #[test]
    fn test_add_and_search() {
        let mut index = ANNIndex::new(4);
        let query = v4(1.0, 0.0, 0.0, 0.0);

        index.add_vector(query.clone(), 1.0);
        index.add_vector(v4(0.0, 1.0, 0.0, 0.0), 2.0);
        index.add_vector(v4(1.0, 0.1, 0.0, 0.0), 1.1);

        let hits = index.search(&query, 2, SearchStrategy::Exact);
        assert_eq!(hits.len(), 2);
        assert!(hits[0].similarity > 0.9);
    }

    #[test]
    fn test_search_strategies() {
        let mut index = ANNIndex::new(4).with_lsh(2, 4).with_random_projections(2);
        let query = v4(1.0, 0.0, 0.0, 0.0);
        index.add_vector(query.clone(), 1.0);

        let strategies = [
            SearchStrategy::Exact,
            SearchStrategy::LSH,
            SearchStrategy::RandomProjection,
            SearchStrategy::Hybrid,
        ];
        for strategy in strategies {
            assert!(!index.search(&query, 1, strategy).is_empty());
        }
    }
}