1use crate::Vector;
4
5pub use crate::VectorIndex;
7use anyhow::{anyhow, Result};
8use oxirs_core::parallel::*;
9use oxirs_core::Triple;
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::sync::Arc;
14
15#[cfg(feature = "hnsw")]
16use hnsw_rs::prelude::*;
17
/// Predicate over a stored vector's URI; the flat search skips every entry for
/// which it returns `false`.
pub type FilterFunction = Box<dyn Fn(&str) -> bool>;
/// `Send + Sync` variant of [`FilterFunction`], the filter type accepted by the
/// parallel flat search.
pub type FilterFunctionSync = Box<dyn Fn(&str) -> bool + Send + Sync>;
22
/// Tunable settings for the vector indices defined in this module.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Index structure to build and search with.
    pub index_type: IndexType,
    /// Forwarded as the first argument of the HNSW constructor when an HNSW
    /// index is built.
    pub max_connections: usize,
    /// Forwarded to the HNSW constructor when an HNSW index is built.
    pub ef_construction: usize,
    /// Search-time `ef` used for HNSW queries when the caller supplies none.
    pub ef_search: usize,
    /// Metric used by the flat (brute-force) and threshold searches.
    pub distance_metric: DistanceMetric,
    /// Allows the flat search to take its parallel path (which is only used
    /// above 1000 stored vectors and when no filter is given).
    pub parallel: bool,
}
39
impl Default for IndexConfig {
    /// HNSW index, 16 connections, `ef_construction` 200, `ef_search` 50,
    /// cosine distance, parallel flat search allowed.
    fn default() -> Self {
        Self {
            index_type: IndexType::Hnsw,
            max_connections: 16,
            ef_construction: 200,
            ef_search: 50,
            distance_metric: DistanceMetric::Cosine,
            parallel: true,
        }
    }
}
52
/// Kind of index structure selected by [`IndexConfig::index_type`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IndexType {
    /// Graph index backed by `hnsw_rs`; only available with the `hnsw` cargo feature.
    Hnsw,
    /// No auxiliary structure: queries scan every stored vector.
    Flat,
    /// Not implemented: `AdvancedVectorIndex::build` returns an error for it.
    Ivf,
    /// Not implemented: `AdvancedVectorIndex::build` returns an error for it.
    PQ,
}
65
/// Distance function applied to pairs of `f32` slices.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Computed by `SimdOps::cosine_distance`.
    Cosine,
    /// Computed by `SimdOps::euclidean_distance`.
    Euclidean,
    /// Computed by `SimdOps::manhattan_distance`.
    Manhattan,
    /// Negated `SimdOps::dot`, so that a larger dot product is a smaller distance.
    DotProduct,
}
78
79impl DistanceMetric {
80 pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
82 use oxirs_core::simd::SimdOps;
83
84 match self {
85 DistanceMetric::Cosine => f32::cosine_distance(a, b),
86 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
87 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
88 DistanceMetric::DotProduct => -f32::dot(a, b), }
90 }
91
92 pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
94 let a_f32 = a.as_f32();
95 let b_f32 = b.as_f32();
96 self.distance(&a_f32, &b_f32)
97 }
98}
99
/// A single hit returned by a search.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// URI of the matched vector.
    pub uri: String,
    /// Distance between the query and the matched vector (smaller is closer).
    pub distance: f32,
    /// Similarity-style score; the searches in this module set it to `1.0 - distance`.
    pub score: f32,
    /// Optional key/value metadata attached to the hit.
    pub metadata: Option<HashMap<String, String>>,
}

// NOTE(review): `PartialEq` is derived over all fields while `Ord` looks at
// `distance` only, so `a.cmp(&b) == Equal` does not imply `a == b`.
impl Eq for SearchResult {}

impl Ord for SearchResult {
    /// Orders hits by ascending `distance`.
    ///
    /// NaN distances sort after every real distance (and equal to each other),
    /// which keeps the ordering total. Treating NaN as equal to everything, as
    /// a bare `partial_cmp(..).unwrap_or(Equal)` does, is not transitive and
    /// breaks the invariants `BinaryHeap` and `sort` rely on.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .partial_cmp(&other.distance)
            // `partial_cmp` is `None` only when at least one side is NaN;
            // `false < true`, so the non-NaN side compares as smaller.
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
    }
}

impl PartialOrd for SearchResult {
    /// Always `Some`, delegating to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
124
/// Vector index that stores its vectors in memory and can optionally build an
/// HNSW graph over them (with the `hnsw` cargo feature).
pub struct AdvancedVectorIndex {
    /// Settings chosen at construction time.
    config: IndexConfig,
    /// Stored `(uri, vector)` pairs in insertion order; a vector's position is its id.
    vectors: Vec<(String, Vector)>,
    /// Maps each URI to the position of its most recent insertion in `vectors`.
    uri_to_id: HashMap<String, usize>,
    /// HNSW graph, present once it has been built.
    // NOTE(review): the distance type is fixed to `DistCosine`, so
    // `config.distance_metric` is not consulted on the HNSW path.
    #[cfg(feature = "hnsw")]
    hnsw_index: Option<Hnsw<'static, f32, DistCosine>>,
    /// Dimensionality of the stored vectors; `None` until the first insert.
    dimensions: Option<usize>,
}
134
135impl AdvancedVectorIndex {
    /// Creates an empty index with the given configuration.
    ///
    /// No vectors are stored, no HNSW graph exists and the dimensionality is
    /// unknown until the first vector is inserted.
    pub fn new(config: IndexConfig) -> Self {
        Self {
            config,
            vectors: Vec::new(),
            uri_to_id: HashMap::new(),
            #[cfg(feature = "hnsw")]
            hnsw_index: None,
            dimensions: None,
        }
    }
146
147 pub fn build(&mut self) -> Result<()> {
149 if self.vectors.is_empty() {
150 return Ok(());
151 }
152
153 match self.config.index_type {
154 IndexType::Hnsw => {
155 #[cfg(feature = "hnsw")]
156 {
157 self.build_hnsw_index()?;
158 }
159 #[cfg(not(feature = "hnsw"))]
160 {
161 return Err(anyhow!("HNSW feature not enabled"));
162 }
163 }
164 IndexType::Flat => {
165 }
167 IndexType::Ivf | IndexType::PQ => {
168 return Err(anyhow!("IVF and PQ indices not yet implemented"));
169 }
170 }
171
172 Ok(())
173 }
174
175 #[cfg(feature = "hnsw")]
176 fn build_hnsw_index(&mut self) -> Result<()> {
177 if let Some(dimensions) = self.dimensions {
178 let mut hnsw = Hnsw::<f32, DistCosine>::new(
179 self.config.max_connections,
180 self.vectors.len(),
181 16, self.config.ef_construction,
183 DistCosine,
184 );
185
186 for (id, (_, vector)) in self.vectors.iter().enumerate() {
187 let vector_f32 = vector.as_f32();
188 hnsw.insert((&vector_f32, id));
189 }
190
191 self.hnsw_index = Some(hnsw);
192 }
193
194 Ok(())
195 }
196
    /// Placeholder: accepts metadata for `_uri` but does not store it.
    ///
    /// Both arguments are ignored and the call always returns `Ok(())`.
    pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
        Ok(())
    }
203
204 pub fn search_advanced(
206 &self,
207 query: &Vector,
208 k: usize,
209 _ef: Option<usize>,
210 filter: Option<FilterFunction>,
211 ) -> Result<Vec<SearchResult>> {
212 match self.config.index_type {
213 IndexType::Hnsw => {
214 #[cfg(feature = "hnsw")]
215 {
216 self.search_hnsw(query, k, ef)
217 }
218 #[cfg(not(feature = "hnsw"))]
219 {
220 self.search_flat(query, k, filter)
221 }
222 }
223 _ => self.search_flat(query, k, filter),
224 }
225 }
226
227 #[cfg(feature = "hnsw")]
228 fn search_hnsw(
229 &self,
230 query: &Vector,
231 k: usize,
232 ef: Option<usize>,
233 ) -> Result<Vec<SearchResult>> {
234 if let Some(ref hnsw) = self.hnsw_index {
235 let search_ef = ef.unwrap_or(self.config.ef_search);
236 let query_f32 = query.as_f32();
237 let results = hnsw.search(&query_f32, k, search_ef);
238
239 Ok(results
240 .into_iter()
241 .map(|result| SearchResult {
242 uri: self.vectors[result.d_id].0.clone(),
243 distance: result.distance,
244 score: 1.0 - result.distance, metadata: None,
246 })
247 .collect())
248 } else {
249 Err(anyhow!("HNSW index not built"))
250 }
251 }
252
253 fn search_flat(
254 &self,
255 query: &Vector,
256 k: usize,
257 filter: Option<FilterFunction>,
258 ) -> Result<Vec<SearchResult>> {
259 if self.config.parallel && self.vectors.len() > 1000 {
260 if filter.is_some() {
262 self.search_flat_sequential(query, k, filter)
264 } else {
265 self.search_flat_parallel(query, k, None)
266 }
267 } else {
268 self.search_flat_sequential(query, k, filter)
269 }
270 }
271
272 fn search_flat_sequential(
273 &self,
274 query: &Vector,
275 k: usize,
276 filter: Option<FilterFunction>,
277 ) -> Result<Vec<SearchResult>> {
278 let mut heap = BinaryHeap::new();
279
280 for (uri, vector) in &self.vectors {
281 if let Some(ref filter_fn) = filter {
282 if !filter_fn(uri) {
283 continue;
284 }
285 }
286
287 let distance = self.config.distance_metric.distance_vectors(query, vector);
288
289 if heap.len() < k {
290 heap.push(std::cmp::Reverse(SearchResult {
291 uri: uri.clone(),
292 distance,
293 score: 1.0 - distance, metadata: None,
295 }));
296 } else if let Some(std::cmp::Reverse(worst)) = heap.peek() {
297 if distance < worst.distance {
298 heap.pop();
299 heap.push(std::cmp::Reverse(SearchResult {
300 uri: uri.clone(),
301 distance,
302 score: 1.0 - distance, metadata: None,
304 }));
305 }
306 }
307 }
308
309 let mut results: Vec<SearchResult> = heap.into_iter().map(|r| r.0).collect();
310 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
311
312 Ok(results)
313 }
314
315 fn search_flat_parallel(
316 &self,
317 query: &Vector,
318 k: usize,
319 filter: Option<FilterFunctionSync>,
320 ) -> Result<Vec<SearchResult>> {
321 let chunk_size = (self.vectors.len() / num_threads()).max(100);
323
324 let filter_arc = filter.map(Arc::new);
326
327 let partial_results: Vec<Vec<SearchResult>> = self
329 .vectors
330 .par_chunks(chunk_size)
331 .map(|chunk| {
332 let mut local_heap = BinaryHeap::new();
333 let filter_ref = filter_arc.as_ref();
334
335 for (uri, vector) in chunk {
336 if let Some(filter_fn) = filter_ref {
337 if !filter_fn(uri) {
338 continue;
339 }
340 }
341
342 let distance = self.config.distance_metric.distance_vectors(query, vector);
343
344 if local_heap.len() < k {
345 local_heap.push(std::cmp::Reverse(SearchResult {
346 uri: uri.clone(),
347 distance,
348 score: 1.0 - distance, metadata: None,
350 }));
351 } else if let Some(std::cmp::Reverse(worst)) = local_heap.peek() {
352 if distance < worst.distance {
353 local_heap.pop();
354 local_heap.push(std::cmp::Reverse(SearchResult {
355 uri: uri.clone(),
356 distance,
357 score: 1.0 - distance, metadata: None,
359 }));
360 }
361 }
362 }
363
364 local_heap
365 .into_sorted_vec()
366 .into_iter()
367 .map(|r| r.0)
368 .collect()
369 })
370 .collect();
371
372 let mut final_heap = BinaryHeap::new();
374 for partial in partial_results {
375 for result in partial {
376 if final_heap.len() < k {
377 final_heap.push(std::cmp::Reverse(result));
378 } else if let Some(std::cmp::Reverse(worst)) = final_heap.peek() {
379 if result.distance < worst.distance {
380 final_heap.pop();
381 final_heap.push(std::cmp::Reverse(result));
382 }
383 }
384 }
385 }
386
387 let mut results: Vec<SearchResult> = final_heap.into_iter().map(|r| r.0).collect();
388 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
389
390 Ok(results)
391 }
392
393 pub fn stats(&self) -> IndexStats {
395 IndexStats {
396 num_vectors: self.vectors.len(),
397 dimensions: self.dimensions.unwrap_or(0),
398 index_type: self.config.index_type,
399 memory_usage: self.estimate_memory_usage(),
400 }
401 }
402
403 fn estimate_memory_usage(&self) -> usize {
404 let vector_memory = self.vectors.len()
405 * (std::mem::size_of::<String>()
406 + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());
407
408 let uri_map_memory =
409 self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());
410
411 vector_memory + uri_map_memory
412 }
413
    /// Number of vectors stored in the index.
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// `true` when no vector has been stored.
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }
423
424 pub fn add(
426 &mut self,
427 id: String,
428 vector: Vec<f32>,
429 _triple: Triple,
430 _metadata: HashMap<String, String>,
431 ) -> Result<()> {
432 let vector_obj = Vector::new(vector);
433 self.insert(id, vector_obj)
434 }
435
436 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
438 let query_vector = Vector::new(query.to_vec());
439 let results = self.search_advanced(&query_vector, k, None, None)?;
440 Ok(results)
441 }
442}
443
444impl VectorIndex for AdvancedVectorIndex {
445 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
446 if let Some(dims) = self.dimensions {
447 if vector.dimensions != dims {
448 return Err(anyhow!(
449 "Vector dimensions ({}) don't match index dimensions ({})",
450 vector.dimensions,
451 dims
452 ));
453 }
454 } else {
455 self.dimensions = Some(vector.dimensions);
456 }
457
458 let id = self.vectors.len();
459 self.uri_to_id.insert(uri.clone(), id);
460 self.vectors.push((uri, vector));
461
462 Ok(())
463 }
464
465 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
466 let results = self.search_advanced(query, k, None, None)?;
467 Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
468 }
469
470 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
471 let mut results = Vec::new();
472
473 for (uri, vector) in &self.vectors {
474 let distance = self.config.distance_metric.distance_vectors(query, vector);
475 if distance <= threshold {
476 results.push((uri.clone(), distance));
477 }
478 }
479
480 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
481 Ok(results)
482 }
483
484 fn get_vector(&self, uri: &str) -> Option<&Vector> {
485 self.vectors.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
488 }
489}
490
/// Summary figures reported by `AdvancedVectorIndex::stats`.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Dimensionality of the stored vectors; `0` when none has been inserted.
    pub dimensions: usize,
    /// Configured index type.
    pub index_type: IndexType,
    /// Estimated memory footprint in bytes.
    pub memory_usage: usize,
}
499
/// Index that stores each vector as a sequence of `u8` centroid codes rather
/// than the original values.
pub struct QuantizedVectorIndex {
    /// Settings chosen at construction time.
    // NOTE(review): not read by any code visible in this file.
    config: IndexConfig,
    /// One code sequence per inserted vector, in insertion order.
    quantized_vectors: Vec<Vec<u8>>,
    /// Centroids produced by training; empty until `train_quantization` runs.
    centroids: Vec<Vector>,
    /// Maps each URI to the position of its codes in `quantized_vectors`.
    uri_to_id: HashMap<String, usize>,
    /// Dimensionality of the first training vector; `None` before training.
    dimensions: Option<usize>,
}
508
509impl QuantizedVectorIndex {
510 pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
511 Self {
512 config,
513 quantized_vectors: Vec::new(),
514 centroids: Vec::with_capacity(num_centroids),
515 uri_to_id: HashMap::new(),
516 dimensions: None,
517 }
518 }
519
520 pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
522 if training_vectors.is_empty() {
523 return Err(anyhow!("No training vectors provided"));
524 }
525
526 let dimensions = training_vectors[0].dimensions;
527 self.dimensions = Some(dimensions);
528
529 self.centroids = kmeans_clustering(training_vectors, self.centroids.capacity())?;
531
532 Ok(())
533 }
534
535 fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
536 let mut quantized = Vec::new();
537
538 let chunk_size = vector.dimensions / self.centroids.len().max(1);
540
541 let vector_f32 = vector.as_f32();
542 for chunk in vector_f32.chunks(chunk_size) {
543 let mut best_centroid = 0u8;
544 let mut best_distance = f32::INFINITY;
545
546 for (i, centroid) in self.centroids.iter().enumerate() {
547 let centroid_f32 = centroid.as_f32();
548 let centroid_chunk = ¢roid_f32[0..chunk.len().min(centroid.dimensions)];
549 use oxirs_core::simd::SimdOps;
550 let distance = f32::euclidean_distance(chunk, centroid_chunk);
551 if distance < best_distance {
552 best_distance = distance;
553 best_centroid = i as u8;
554 }
555 }
556
557 quantized.push(best_centroid);
558 }
559
560 quantized
561 }
562}
563
564impl VectorIndex for QuantizedVectorIndex {
565 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
566 if self.centroids.is_empty() {
567 return Err(anyhow!(
568 "Quantization not trained. Call train_quantization first."
569 ));
570 }
571
572 let id = self.quantized_vectors.len();
573 self.uri_to_id.insert(uri.clone(), id);
574
575 let quantized = self.quantize_vector(&vector);
576 self.quantized_vectors.push(quantized);
577
578 Ok(())
579 }
580
581 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
582 let query_quantized = self.quantize_vector(query);
583 let mut results = Vec::new();
584
585 for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
586 let distance = hamming_distance(&query_quantized, quantized);
587 results.push((uri.clone(), distance));
588 }
589
590 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
591 results.truncate(k);
592
593 Ok(results)
594 }
595
596 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
597 let query_quantized = self.quantize_vector(query);
598 let mut results = Vec::new();
599
600 for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
601 let distance = hamming_distance(&query_quantized, quantized);
602 if distance <= threshold {
603 results.push((uri.clone(), distance));
604 }
605 }
606
607 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
608 Ok(results)
609 }
610
611 fn get_vector(&self, _uri: &str) -> Option<&Vector> {
612 None
615 }
616}
617
/// Counts the positions at which `a` and `b` hold different bytes.
///
/// Only the common prefix is compared: positions past the end of the shorter
/// slice are ignored.
fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    let mut mismatches = 0usize;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        if lhs != rhs {
            mismatches += 1;
        }
    }
    mismatches as f32
}
623
624fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
626 if vectors.is_empty() || k == 0 {
627 return Ok(Vec::new());
628 }
629
630 let dimensions = vectors[0].dimensions;
631 let mut centroids = Vec::with_capacity(k);
632
633 for i in 0..k {
635 let idx = i % vectors.len();
636 centroids.push(vectors[idx].clone());
637 }
638
639 for _ in 0..10 {
641 let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];
642
643 for vector in vectors {
645 let mut best_centroid = 0;
646 let mut best_distance = f32::INFINITY;
647
648 for (i, centroid) in centroids.iter().enumerate() {
649 let vector_f32 = vector.as_f32();
650 let centroid_f32 = centroid.as_f32();
651 use oxirs_core::simd::SimdOps;
652 let distance = f32::euclidean_distance(&vector_f32, ¢roid_f32);
653 if distance < best_distance {
654 best_distance = distance;
655 best_centroid = i;
656 }
657 }
658
659 clusters[best_centroid].push(vector);
660 }
661
662 for (i, cluster) in clusters.iter().enumerate() {
664 if !cluster.is_empty() {
665 let mut new_centroid = vec![0.0; dimensions];
666
667 for vector in cluster {
668 let vector_f32 = vector.as_f32();
669 for (j, &value) in vector_f32.iter().enumerate() {
670 new_centroid[j] += value;
671 }
672 }
673
674 for value in &mut new_centroid {
675 *value /= cluster.len() as f32;
676 }
677
678 centroids[i] = Vector::new(new_centroid);
679 }
680 }
681 }
682
683 Ok(centroids)
684}
685
/// A set of named indices, one of which serves as the default target.
pub struct MultiIndex {
    /// Registered indices, keyed by name.
    indices: HashMap<String, Box<dyn VectorIndex>>,
    /// Name of the index used by the `VectorIndex` implementation; empty until
    /// the first index is added.
    default_index: String,
}
691
692impl MultiIndex {
693 pub fn new() -> Self {
694 Self {
695 indices: HashMap::new(),
696 default_index: String::new(),
697 }
698 }
699
700 pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
701 if self.indices.is_empty() {
702 self.default_index = name.clone();
703 }
704 self.indices.insert(name, index);
705 }
706
707 pub fn set_default(&mut self, name: &str) -> Result<()> {
708 if self.indices.contains_key(name) {
709 self.default_index = name.to_string();
710 Ok(())
711 } else {
712 Err(anyhow!("Index '{}' not found", name))
713 }
714 }
715
716 pub fn search_index(
717 &self,
718 index_name: &str,
719 query: &Vector,
720 k: usize,
721 ) -> Result<Vec<(String, f32)>> {
722 if let Some(index) = self.indices.get(index_name) {
723 index.search_knn(query, k)
724 } else {
725 Err(anyhow!("Index '{}' not found", index_name))
726 }
727 }
728}
729
impl Default for MultiIndex {
    /// Equivalent to [`MultiIndex::new`].
    fn default() -> Self {
        Self::new()
    }
}
735
736impl VectorIndex for MultiIndex {
737 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
738 if let Some(index) = self.indices.get_mut(&self.default_index) {
739 index.insert(uri, vector)
740 } else {
741 Err(anyhow!("No default index set"))
742 }
743 }
744
745 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
746 if let Some(index) = self.indices.get(&self.default_index) {
747 index.search_knn(query, k)
748 } else {
749 Err(anyhow!("No default index set"))
750 }
751 }
752
753 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
754 if let Some(index) = self.indices.get(&self.default_index) {
755 index.search_threshold(query, threshold)
756 } else {
757 Err(anyhow!("No default index set"))
758 }
759 }
760
761 fn get_vector(&self, uri: &str) -> Option<&Vector> {
762 if let Some(index) = self.indices.get(&self.default_index) {
763 index.get_vector(uri)
764 } else {
765 None
766 }
767 }
768}