1use crate::Vector;
4
5pub use crate::VectorIndex;
7use anyhow::{anyhow, Result};
8use oxirs_core::parallel::*;
9use oxirs_core::Triple;
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::sync::Arc;
14
15#[cfg(feature = "hnsw")]
16use hnsw_rs::prelude::*;
17
/// Boxed predicate over a stored vector's URI.
///
/// The flat scan calls it with each candidate URI and skips the candidate
/// when it returns `false`.
pub type FilterFunction = Box<dyn Fn(&str) -> bool>;
/// Same predicate shape as [`FilterFunction`], additionally bounded by
/// `Send + Sync`; this is the filter type accepted by the parallel flat scan.
pub type FilterFunctionSync = Box<dyn Fn(&str) -> bool + Send + Sync>;
22
/// Tunable parameters for building and searching a vector index.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Which index structure to build and search.
    pub index_type: IndexType,
    /// Passed as the first argument to the HNSW constructor when an HNSW graph is built.
    pub max_connections: usize,
    /// Construction-time `ef` value handed to the HNSW constructor.
    pub ef_construction: usize,
    /// Search-time `ef` used by HNSW queries when the caller does not supply one.
    pub ef_search: usize,
    /// Metric used by the flat (exhaustive) scans and threshold searches.
    pub distance_metric: DistanceMetric,
    /// Allows the flat scan to take its parallel path (only used for
    /// unfiltered searches over more than 1000 stored vectors).
    pub parallel: bool,
}
39
40impl Default for IndexConfig {
41 fn default() -> Self {
42 Self {
43 index_type: IndexType::Hnsw,
44 max_connections: 16,
45 ef_construction: 200,
46 ef_search: 50,
47 distance_metric: DistanceMetric::Cosine,
48 parallel: true,
49 }
50 }
51}
52
/// Kind of index structure selected through [`IndexConfig`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IndexType {
    /// HNSW graph index; building it requires the `hnsw` crate feature.
    Hnsw,
    /// Exhaustive scan over all stored vectors; needs no build step.
    Flat,
    /// Inverted-file index. Building it currently returns a "not yet implemented" error.
    Ivf,
    /// Product-quantization index. Building it currently returns a "not yet implemented" error.
    PQ,
}
65
/// Distance function used to compare two vectors; smaller values rank first in searches.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine distance as computed by `SimdOps::cosine_distance`.
    Cosine,
    /// Euclidean distance as computed by `SimdOps::euclidean_distance`.
    Euclidean,
    /// Manhattan distance as computed by `SimdOps::manhattan_distance`.
    Manhattan,
    /// Negated dot product, so that a larger dot product gives a smaller distance.
    DotProduct,
}
78
79impl DistanceMetric {
80 pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
82 use oxirs_core::simd::SimdOps;
83
84 match self {
85 DistanceMetric::Cosine => f32::cosine_distance(a, b),
86 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
87 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
88 DistanceMetric::DotProduct => -f32::dot(a, b), }
90 }
91
92 pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
94 let a_f32 = a.as_f32();
95 let b_f32 = b.as_f32();
96 self.distance(&a_f32, &b_f32)
97 }
98}
99
/// One hit returned by a vector search.
///
/// Equality (derived) compares every field, whereas ordering looks only at
/// `distance`. Two results that are not `==` can therefore still compare as
/// `Ordering::Equal`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// URI of the matched vector.
    pub uri: String,
    /// Distance between the query and the matched vector (smaller is closer).
    pub distance: f32,
    /// Score reported alongside the distance by the search that produced the hit.
    pub score: f32,
    /// Optional key/value metadata attached to the hit.
    pub metadata: Option<HashMap<String, String>>,
}

// `Eq` is asserted so the type can implement `Ord` and live in a `BinaryHeap`.
// Note that the derived `PartialEq` compares `f32` fields, so a value holding
// NaN is not equal to itself.
impl Eq for SearchResult {}

impl Ord for SearchResult {
    /// Orders results by `distance` using the IEEE 754 total order.
    ///
    /// `f32::total_cmp` gives NaN a fixed position instead of treating it as
    /// equal to every other value, so the ordering stays transitive and is
    /// safe to use in sorted containers and heaps.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}

impl PartialOrd for SearchResult {
    /// Always `Some`, delegating to the total order defined by [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
124
/// In-memory vector index that always supports an exhaustive scan and can
/// additionally build an HNSW graph when the `hnsw` feature is enabled.
pub struct AdvancedVectorIndex {
    /// Index type, HNSW parameters, distance metric and parallelism switch.
    config: IndexConfig,
    /// Stored `(uri, vector)` pairs in insertion order; the position doubles as the vector id.
    vectors: Vec<(String, Vector)>,
    /// Maps a URI to the position in `vectors` recorded by its most recent insertion.
    uri_to_id: HashMap<String, usize>,
    /// HNSW graph over `vectors`; `None` until it has been built.
    #[cfg(feature = "hnsw")]
    hnsw_index: Option<Hnsw<'static, f32, DistCosine>>,
    /// Dimensionality fixed by the first inserted vector; `None` while nothing was inserted.
    dimensions: Option<usize>,
}
134
135impl AdvancedVectorIndex {
136 pub fn new(config: IndexConfig) -> Self {
137 Self {
138 config,
139 vectors: Vec::new(),
140 uri_to_id: HashMap::new(),
141 #[cfg(feature = "hnsw")]
142 hnsw_index: None,
143 dimensions: None,
144 }
145 }
146
147 pub fn build(&mut self) -> Result<()> {
149 if self.vectors.is_empty() {
150 return Ok(());
151 }
152
153 match self.config.index_type {
154 IndexType::Hnsw => {
155 #[cfg(feature = "hnsw")]
156 {
157 self.build_hnsw_index()?;
158 }
159 #[cfg(not(feature = "hnsw"))]
160 {
161 return Err(anyhow!("HNSW feature not enabled"));
162 }
163 }
164 IndexType::Flat => {
165 }
167 IndexType::Ivf | IndexType::PQ => {
168 return Err(anyhow!("IVF and PQ indices not yet implemented"));
169 }
170 }
171
172 Ok(())
173 }
174
175 #[cfg(feature = "hnsw")]
176 fn build_hnsw_index(&mut self) -> Result<()> {
177 if let Some(_dimensions) = self.dimensions {
178 let hnsw = Hnsw::<f32, DistCosine>::new(
179 self.config.max_connections,
180 self.vectors.len(),
181 16, self.config.ef_construction,
183 DistCosine,
184 );
185
186 for (id, (_, vector)) in self.vectors.iter().enumerate() {
187 let vector_f32 = vector.as_f32();
188 hnsw.insert((&vector_f32, id));
189 }
190
191 self.hnsw_index = Some(hnsw);
192 }
193
194 Ok(())
195 }
196
197 pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
199 Ok(())
202 }
203
204 pub fn search_advanced(
206 &self,
207 query: &Vector,
208 k: usize,
209 ef: Option<usize>,
210 filter: Option<FilterFunction>,
211 ) -> Result<Vec<SearchResult>> {
212 match self.config.index_type {
213 IndexType::Hnsw => {
214 #[cfg(feature = "hnsw")]
215 {
216 self.search_hnsw(query, k, ef)
217 }
218 #[cfg(not(feature = "hnsw"))]
219 {
220 let _ = ef;
221 self.search_flat(query, k, filter)
222 }
223 }
224 _ => self.search_flat(query, k, filter),
225 }
226 }
227
228 #[cfg(feature = "hnsw")]
229 fn search_hnsw(
230 &self,
231 query: &Vector,
232 k: usize,
233 ef: Option<usize>,
234 ) -> Result<Vec<SearchResult>> {
235 if let Some(ref hnsw) = self.hnsw_index {
236 let search_ef = ef.unwrap_or(self.config.ef_search);
237 let query_f32 = query.as_f32();
238 let results = hnsw.search(&query_f32, k, search_ef);
239
240 Ok(results
241 .into_iter()
242 .map(|result| SearchResult {
243 uri: self.vectors[result.d_id].0.clone(),
244 distance: result.distance,
245 score: 1.0 - result.distance, metadata: None,
247 })
248 .collect())
249 } else {
250 Err(anyhow!("HNSW index not built"))
251 }
252 }
253
254 fn search_flat(
255 &self,
256 query: &Vector,
257 k: usize,
258 filter: Option<FilterFunction>,
259 ) -> Result<Vec<SearchResult>> {
260 if self.config.parallel && self.vectors.len() > 1000 {
261 if filter.is_some() {
263 self.search_flat_sequential(query, k, filter)
265 } else {
266 self.search_flat_parallel(query, k, None)
267 }
268 } else {
269 self.search_flat_sequential(query, k, filter)
270 }
271 }
272
273 fn search_flat_sequential(
274 &self,
275 query: &Vector,
276 k: usize,
277 filter: Option<FilterFunction>,
278 ) -> Result<Vec<SearchResult>> {
279 let mut heap = BinaryHeap::new();
280
281 for (uri, vector) in &self.vectors {
282 if let Some(ref filter_fn) = filter {
283 if !filter_fn(uri) {
284 continue;
285 }
286 }
287
288 let distance = self.config.distance_metric.distance_vectors(query, vector);
289
290 if heap.len() < k {
291 heap.push(std::cmp::Reverse(SearchResult {
292 uri: uri.clone(),
293 distance,
294 score: 1.0 - distance, metadata: None,
296 }));
297 } else if let Some(std::cmp::Reverse(worst)) = heap.peek() {
298 if distance < worst.distance {
299 heap.pop();
300 heap.push(std::cmp::Reverse(SearchResult {
301 uri: uri.clone(),
302 distance,
303 score: 1.0 - distance, metadata: None,
305 }));
306 }
307 }
308 }
309
310 let mut results: Vec<SearchResult> = heap.into_iter().map(|r| r.0).collect();
311 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
312
313 Ok(results)
314 }
315
316 fn search_flat_parallel(
317 &self,
318 query: &Vector,
319 k: usize,
320 filter: Option<FilterFunctionSync>,
321 ) -> Result<Vec<SearchResult>> {
322 let chunk_size = (self.vectors.len() / num_threads()).max(100);
324
325 let filter_arc = filter.map(Arc::new);
327
328 let partial_results: Vec<Vec<SearchResult>> = self
330 .vectors
331 .par_chunks(chunk_size)
332 .map(|chunk| {
333 let mut local_heap = BinaryHeap::new();
334 let filter_ref = filter_arc.as_ref();
335
336 for (uri, vector) in chunk {
337 if let Some(filter_fn) = filter_ref {
338 if !filter_fn(uri) {
339 continue;
340 }
341 }
342
343 let distance = self.config.distance_metric.distance_vectors(query, vector);
344
345 if local_heap.len() < k {
346 local_heap.push(std::cmp::Reverse(SearchResult {
347 uri: uri.clone(),
348 distance,
349 score: 1.0 - distance, metadata: None,
351 }));
352 } else if let Some(std::cmp::Reverse(worst)) = local_heap.peek() {
353 if distance < worst.distance {
354 local_heap.pop();
355 local_heap.push(std::cmp::Reverse(SearchResult {
356 uri: uri.clone(),
357 distance,
358 score: 1.0 - distance, metadata: None,
360 }));
361 }
362 }
363 }
364
365 local_heap
366 .into_sorted_vec()
367 .into_iter()
368 .map(|r| r.0)
369 .collect()
370 })
371 .collect();
372
373 let mut final_heap = BinaryHeap::new();
375 for partial in partial_results {
376 for result in partial {
377 if final_heap.len() < k {
378 final_heap.push(std::cmp::Reverse(result));
379 } else if let Some(std::cmp::Reverse(worst)) = final_heap.peek() {
380 if result.distance < worst.distance {
381 final_heap.pop();
382 final_heap.push(std::cmp::Reverse(result));
383 }
384 }
385 }
386 }
387
388 let mut results: Vec<SearchResult> = final_heap.into_iter().map(|r| r.0).collect();
389 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
390
391 Ok(results)
392 }
393
394 pub fn stats(&self) -> IndexStats {
396 IndexStats {
397 num_vectors: self.vectors.len(),
398 dimensions: self.dimensions.unwrap_or(0),
399 index_type: self.config.index_type,
400 memory_usage: self.estimate_memory_usage(),
401 }
402 }
403
404 fn estimate_memory_usage(&self) -> usize {
405 let vector_memory = self.vectors.len()
406 * (std::mem::size_of::<String>()
407 + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());
408
409 let uri_map_memory =
410 self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());
411
412 vector_memory + uri_map_memory
413 }
414
415 pub fn len(&self) -> usize {
417 self.vectors.len()
418 }
419
420 pub fn is_empty(&self) -> bool {
422 self.vectors.is_empty()
423 }
424
425 pub fn add(
427 &mut self,
428 id: String,
429 vector: Vec<f32>,
430 _triple: Triple,
431 _metadata: HashMap<String, String>,
432 ) -> Result<()> {
433 let vector_obj = Vector::new(vector);
434 self.insert(id, vector_obj)
435 }
436
437 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
439 let query_vector = Vector::new(query.to_vec());
440 let results = self.search_advanced(&query_vector, k, None, None)?;
441 Ok(results)
442 }
443}
444
445impl VectorIndex for AdvancedVectorIndex {
446 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
447 if let Some(dims) = self.dimensions {
448 if vector.dimensions != dims {
449 return Err(anyhow!(
450 "Vector dimensions ({}) don't match index dimensions ({})",
451 vector.dimensions,
452 dims
453 ));
454 }
455 } else {
456 self.dimensions = Some(vector.dimensions);
457 }
458
459 let id = self.vectors.len();
460 self.uri_to_id.insert(uri.clone(), id);
461 self.vectors.push((uri, vector));
462
463 Ok(())
464 }
465
466 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
467 let results = self.search_advanced(query, k, None, None)?;
468 Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
469 }
470
471 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
472 let mut results = Vec::new();
473
474 for (uri, vector) in &self.vectors {
475 let distance = self.config.distance_metric.distance_vectors(query, vector);
476 if distance <= threshold {
477 results.push((uri.clone(), distance));
478 }
479 }
480
481 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
482 Ok(results)
483 }
484
485 fn get_vector(&self, uri: &str) -> Option<&Vector> {
486 self.vectors.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
489 }
490}
491
/// Summary statistics reported by `AdvancedVectorIndex::stats`.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Vector dimensionality, or 0 when nothing has been inserted yet.
    pub dimensions: usize,
    /// Configured index type.
    pub index_type: IndexType,
    /// Estimated memory footprint in bytes.
    pub memory_usage: usize,
}
500
501pub struct QuantizedVectorIndex {
503 config: IndexConfig,
504 quantized_vectors: Vec<Vec<u8>>,
505 centroids: Vec<Vector>,
506 uri_to_id: HashMap<String, usize>,
507 dimensions: Option<usize>,
508}
509
510impl QuantizedVectorIndex {
511 pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
512 Self {
513 config,
514 quantized_vectors: Vec::new(),
515 centroids: Vec::with_capacity(num_centroids),
516 uri_to_id: HashMap::new(),
517 dimensions: None,
518 }
519 }
520
521 pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
523 if training_vectors.is_empty() {
524 return Err(anyhow!("No training vectors provided"));
525 }
526
527 let dimensions = training_vectors[0].dimensions;
528 self.dimensions = Some(dimensions);
529
530 self.centroids = kmeans_clustering(training_vectors, self.centroids.capacity())?;
532
533 Ok(())
534 }
535
536 fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
537 let mut quantized = Vec::new();
538
539 let chunk_size = vector.dimensions / self.centroids.len().max(1);
541
542 let vector_f32 = vector.as_f32();
543 for chunk in vector_f32.chunks(chunk_size) {
544 let mut best_centroid = 0u8;
545 let mut best_distance = f32::INFINITY;
546
547 for (i, centroid) in self.centroids.iter().enumerate() {
548 let centroid_f32 = centroid.as_f32();
549 let centroid_chunk = ¢roid_f32[0..chunk.len().min(centroid.dimensions)];
550 use oxirs_core::simd::SimdOps;
551 let distance = f32::euclidean_distance(chunk, centroid_chunk);
552 if distance < best_distance {
553 best_distance = distance;
554 best_centroid = i as u8;
555 }
556 }
557
558 quantized.push(best_centroid);
559 }
560
561 quantized
562 }
563}
564
565impl VectorIndex for QuantizedVectorIndex {
566 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
567 if self.centroids.is_empty() {
568 return Err(anyhow!(
569 "Quantization not trained. Call train_quantization first."
570 ));
571 }
572
573 let id = self.quantized_vectors.len();
574 self.uri_to_id.insert(uri.clone(), id);
575
576 let quantized = self.quantize_vector(&vector);
577 self.quantized_vectors.push(quantized);
578
579 Ok(())
580 }
581
582 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
583 let query_quantized = self.quantize_vector(query);
584 let mut results = Vec::new();
585
586 for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
587 let distance = hamming_distance(&query_quantized, quantized);
588 results.push((uri.clone(), distance));
589 }
590
591 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
592 results.truncate(k);
593
594 Ok(results)
595 }
596
597 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
598 let query_quantized = self.quantize_vector(query);
599 let mut results = Vec::new();
600
601 for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
602 let distance = hamming_distance(&query_quantized, quantized);
603 if distance <= threshold {
604 results.push((uri.clone(), distance));
605 }
606 }
607
608 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
609 Ok(results)
610 }
611
612 fn get_vector(&self, _uri: &str) -> Option<&Vector> {
613 None
616 }
617}
618
/// Counts the positions at which `a` and `b` hold different bytes.
///
/// Only the overlapping prefix is compared: when the slices differ in length,
/// the surplus elements of the longer one are ignored.
fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    let mut mismatches = 0usize;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        if lhs != rhs {
            mismatches += 1;
        }
    }
    mismatches as f32
}
624
625fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
627 if vectors.is_empty() || k == 0 {
628 return Ok(Vec::new());
629 }
630
631 let dimensions = vectors[0].dimensions;
632 let mut centroids = Vec::with_capacity(k);
633
634 for i in 0..k {
636 let idx = i % vectors.len();
637 centroids.push(vectors[idx].clone());
638 }
639
640 for _ in 0..10 {
642 let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];
643
644 for vector in vectors {
646 let mut best_centroid = 0;
647 let mut best_distance = f32::INFINITY;
648
649 for (i, centroid) in centroids.iter().enumerate() {
650 let vector_f32 = vector.as_f32();
651 let centroid_f32 = centroid.as_f32();
652 use oxirs_core::simd::SimdOps;
653 let distance = f32::euclidean_distance(&vector_f32, ¢roid_f32);
654 if distance < best_distance {
655 best_distance = distance;
656 best_centroid = i;
657 }
658 }
659
660 clusters[best_centroid].push(vector);
661 }
662
663 for (i, cluster) in clusters.iter().enumerate() {
665 if !cluster.is_empty() {
666 let mut new_centroid = vec![0.0; dimensions];
667
668 for vector in cluster {
669 let vector_f32 = vector.as_f32();
670 for (j, &value) in vector_f32.iter().enumerate() {
671 new_centroid[j] += value;
672 }
673 }
674
675 for value in &mut new_centroid {
676 *value /= cluster.len() as f32;
677 }
678
679 centroids[i] = Vector::new(new_centroid);
680 }
681 }
682 }
683
684 Ok(centroids)
685}
686
/// Collection of named vector indices, one of which is the default target
/// for the `VectorIndex` trait methods.
pub struct MultiIndex {
    /// Registered indices keyed by name.
    indices: HashMap<String, Box<dyn VectorIndex>>,
    /// Name of the index that trait-level operations are forwarded to;
    /// empty until an index has been registered.
    default_index: String,
}
692
693impl MultiIndex {
694 pub fn new() -> Self {
695 Self {
696 indices: HashMap::new(),
697 default_index: String::new(),
698 }
699 }
700
701 pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
702 if self.indices.is_empty() {
703 self.default_index = name.clone();
704 }
705 self.indices.insert(name, index);
706 }
707
708 pub fn set_default(&mut self, name: &str) -> Result<()> {
709 if self.indices.contains_key(name) {
710 self.default_index = name.to_string();
711 Ok(())
712 } else {
713 Err(anyhow!("Index '{}' not found", name))
714 }
715 }
716
717 pub fn search_index(
718 &self,
719 index_name: &str,
720 query: &Vector,
721 k: usize,
722 ) -> Result<Vec<(String, f32)>> {
723 if let Some(index) = self.indices.get(index_name) {
724 index.search_knn(query, k)
725 } else {
726 Err(anyhow!("Index '{}' not found", index_name))
727 }
728 }
729}
730
731impl Default for MultiIndex {
732 fn default() -> Self {
733 Self::new()
734 }
735}
736
737impl VectorIndex for MultiIndex {
738 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
739 if let Some(index) = self.indices.get_mut(&self.default_index) {
740 index.insert(uri, vector)
741 } else {
742 Err(anyhow!("No default index set"))
743 }
744 }
745
746 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
747 if let Some(index) = self.indices.get(&self.default_index) {
748 index.search_knn(query, k)
749 } else {
750 Err(anyhow!("No default index set"))
751 }
752 }
753
754 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
755 if let Some(index) = self.indices.get(&self.default_index) {
756 index.search_threshold(query, threshold)
757 } else {
758 Err(anyhow!("No default index set"))
759 }
760 }
761
762 fn get_vector(&self, uri: &str) -> Option<&Vector> {
763 if let Some(index) = self.indices.get(&self.default_index) {
764 index.get_vector(uri)
765 } else {
766 None
767 }
768 }
769}