1use crate::Vector;
4
5pub use crate::VectorIndex;
7use anyhow::{anyhow, Result};
8use oxirs_core::parallel::*;
9use oxirs_core::Triple;
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::sync::Arc;
14
15use crate::hnsw::{HnswConfig, HnswIndex};
16
/// Predicate over a candidate's URI: return `true` to keep the candidate, `false` to skip it.
pub type FilterFunction = Box<dyn Fn(&str) -> bool>;
/// `Send + Sync` variant of [`FilterFunction`], callable from parallel scan workers.
pub type FilterFunctionSync = Box<dyn Fn(&str) -> bool + Send + Sync>;
21
/// Tuning parameters for [`AdvancedVectorIndex`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct IndexConfig {
    /// Which index structure `build` creates and which search path is used.
    pub index_type: IndexType,
    /// Passed to HNSW as `m`; layer 0 (`m_l0`) is given twice this value.
    pub max_connections: usize,
    /// Passed to HNSW as `ef_construction`.
    pub ef_construction: usize,
    /// Passed to HNSW as `ef`.
    pub ef_search: usize,
    /// Metric used by flat (brute-force) and threshold searches.
    /// NOTE(review): not forwarded to the HNSW configuration by `build`.
    pub distance_metric: DistanceMetric,
    /// Allow unfiltered flat searches over more than 1000 vectors to run in parallel.
    pub parallel: bool,
}
38
39impl Default for IndexConfig {
40 fn default() -> Self {
41 Self {
42 index_type: IndexType::Hnsw,
43 max_connections: 16,
44 ef_construction: 200,
45 ef_search: 50,
46 distance_metric: DistanceMetric::Cosine,
47 parallel: true,
48 }
49 }
50}
51
/// Kind of index structure built by [`AdvancedVectorIndex::build`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IndexType {
    /// HNSW graph (approximate nearest-neighbour search).
    Hnsw,
    /// Brute-force scan over every stored vector (exact search).
    Flat,
    /// Inverted-file index; `build` currently rejects it as not implemented,
    /// and searches fall back to a flat scan.
    Ivf,
    /// Product quantization; `build` currently rejects it as not implemented,
    /// and searches fall back to a flat scan.
    PQ,
}
64
/// Distance function used to rank vectors; searches treat smaller values as closer.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine distance as computed by `SimdOps::cosine_distance`.
    Cosine,
    /// Euclidean (L2) distance.
    Euclidean,
    /// Manhattan (L1) distance.
    Manhattan,
    /// Negated dot product, so that a larger dot product ranks as closer.
    DotProduct,
}
77
78impl DistanceMetric {
79 pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
81 use oxirs_core::simd::SimdOps;
82
83 match self {
84 DistanceMetric::Cosine => f32::cosine_distance(a, b),
85 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
86 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
87 DistanceMetric::DotProduct => -f32::dot(a, b), }
89 }
90
91 pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
93 let a_f32 = a.as_f32();
94 let b_f32 = b.as_f32();
95 self.distance(&a_f32, &b_f32)
96 }
97}
98
/// One hit returned by a search.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// URI of the matched vector.
    pub uri: String,
    /// Distance from the query under the index's metric (smaller is closer).
    pub distance: f32,
    /// Similarity score; the searches in this module set it to `1.0 - distance`.
    pub score: f32,
    /// Optional metadata; the searches in this module always leave it `None`.
    pub metadata: Option<HashMap<String, String>>,
}
107
// `Eq` is declared by hand because the `f32` fields prevent deriving it.
// NOTE(review): the derived `==` is not reflexive for NaN `distance`/`score`,
// so the `Eq` contract only holds for non-NaN values.
impl Eq for SearchResult {}
109
110impl Ord for SearchResult {
111 fn cmp(&self, other: &Self) -> Ordering {
112 self.distance
113 .partial_cmp(&other.distance)
114 .unwrap_or(Ordering::Equal)
115 }
116}
117
118impl PartialOrd for SearchResult {
119 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
120 Some(self.cmp(other))
121 }
122}
123
/// In-memory vector index answering k-NN queries by brute-force scan or via an HNSW graph.
pub struct AdvancedVectorIndex {
    /// Index type, HNSW parameters, distance metric and parallelism switch.
    config: IndexConfig,
    /// Stored `(uri, vector)` pairs in insertion order; a pair's position is its id.
    vectors: Vec<(String, Vector)>,
    /// URI -> position in `vectors`; for duplicate URIs the latest insert wins.
    uri_to_id: HashMap<String, usize>,
    /// HNSW graph, present only after `build()` with `IndexType::Hnsw` on a non-empty index.
    hnsw_index: Option<HnswIndex>,
    /// Dimensionality fixed by the first inserted vector.
    dimensions: Option<usize>,
}
132
133impl AdvancedVectorIndex {
134 pub fn new(config: IndexConfig) -> Self {
135 Self {
136 config,
137 vectors: Vec::new(),
138 uri_to_id: HashMap::new(),
139 hnsw_index: None,
140 dimensions: None,
141 }
142 }
143
144 pub fn build(&mut self) -> Result<()> {
146 if self.vectors.is_empty() {
147 return Ok(());
148 }
149
150 match self.config.index_type {
151 IndexType::Hnsw => {
152 self.build_hnsw_index()?;
153 }
154 IndexType::Flat => {
155 }
157 IndexType::Ivf | IndexType::PQ => {
158 return Err(anyhow!("IVF and PQ indices not yet implemented"));
159 }
160 }
161
162 Ok(())
163 }
164
165 fn build_hnsw_index(&mut self) -> Result<()> {
166 if self.dimensions.is_some() {
167 let hnsw_config = HnswConfig {
168 m: self.config.max_connections,
169 m_l0: self.config.max_connections * 2,
170 ef_construction: self.config.ef_construction,
171 ef: self.config.ef_search,
172 ..HnswConfig::default()
173 };
174
175 let mut hnsw = HnswIndex::new_cpu_only(hnsw_config);
176
177 for (uri, vector) in &self.vectors {
178 hnsw.insert(uri.clone(), vector.clone())?;
179 }
180
181 self.hnsw_index = Some(hnsw);
182 }
183
184 Ok(())
185 }
186
187 pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
189 Ok(())
192 }
193
194 pub fn search_advanced(
196 &self,
197 query: &Vector,
198 k: usize,
199 _ef: Option<usize>,
200 filter: Option<FilterFunction>,
201 ) -> Result<Vec<SearchResult>> {
202 match self.config.index_type {
203 IndexType::Hnsw => self.search_hnsw(query, k),
204 _ => self.search_flat(query, k, filter),
205 }
206 }
207
208 fn search_hnsw(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
209 if let Some(ref hnsw) = self.hnsw_index {
210 let results = hnsw.search_knn(query, k)?;
211
212 Ok(results
213 .into_iter()
214 .map(|(uri, distance)| SearchResult {
215 uri,
216 distance,
217 score: 1.0 - distance,
218 metadata: None,
219 })
220 .collect())
221 } else {
222 Err(anyhow!("HNSW index not built"))
223 }
224 }
225
226 fn search_flat(
227 &self,
228 query: &Vector,
229 k: usize,
230 filter: Option<FilterFunction>,
231 ) -> Result<Vec<SearchResult>> {
232 if self.config.parallel && self.vectors.len() > 1000 {
233 if filter.is_some() {
235 self.search_flat_sequential(query, k, filter)
237 } else {
238 self.search_flat_parallel(query, k, None)
239 }
240 } else {
241 self.search_flat_sequential(query, k, filter)
242 }
243 }
244
245 fn search_flat_sequential(
246 &self,
247 query: &Vector,
248 k: usize,
249 filter: Option<FilterFunction>,
250 ) -> Result<Vec<SearchResult>> {
251 let mut heap = BinaryHeap::new();
252
253 for (uri, vector) in &self.vectors {
254 if let Some(ref filter_fn) = filter {
255 if !filter_fn(uri) {
256 continue;
257 }
258 }
259
260 let distance = self.config.distance_metric.distance_vectors(query, vector);
261
262 if heap.len() < k {
263 heap.push(std::cmp::Reverse(SearchResult {
264 uri: uri.clone(),
265 distance,
266 score: 1.0 - distance, metadata: None,
268 }));
269 } else if let Some(std::cmp::Reverse(worst)) = heap.peek() {
270 if distance < worst.distance {
271 heap.pop();
272 heap.push(std::cmp::Reverse(SearchResult {
273 uri: uri.clone(),
274 distance,
275 score: 1.0 - distance, metadata: None,
277 }));
278 }
279 }
280 }
281
282 let mut results: Vec<SearchResult> = heap.into_iter().map(|r| r.0).collect();
283 results.sort_by(|a, b| {
284 a.distance
285 .partial_cmp(&b.distance)
286 .unwrap_or(std::cmp::Ordering::Equal)
287 });
288
289 Ok(results)
290 }
291
292 fn search_flat_parallel(
293 &self,
294 query: &Vector,
295 k: usize,
296 filter: Option<FilterFunctionSync>,
297 ) -> Result<Vec<SearchResult>> {
298 let chunk_size = (self.vectors.len() / num_threads()).max(100);
300
301 let filter_arc = filter.map(Arc::new);
303
304 let partial_results: Vec<Vec<SearchResult>> = self
306 .vectors
307 .par_chunks(chunk_size)
308 .map(|chunk| {
309 let mut local_heap = BinaryHeap::new();
310 let filter_ref = filter_arc.as_ref();
311
312 for (uri, vector) in chunk {
313 if let Some(filter_fn) = filter_ref {
314 if !filter_fn(uri) {
315 continue;
316 }
317 }
318
319 let distance = self.config.distance_metric.distance_vectors(query, vector);
320
321 if local_heap.len() < k {
322 local_heap.push(std::cmp::Reverse(SearchResult {
323 uri: uri.clone(),
324 distance,
325 score: 1.0 - distance, metadata: None,
327 }));
328 } else if let Some(std::cmp::Reverse(worst)) = local_heap.peek() {
329 if distance < worst.distance {
330 local_heap.pop();
331 local_heap.push(std::cmp::Reverse(SearchResult {
332 uri: uri.clone(),
333 distance,
334 score: 1.0 - distance, metadata: None,
336 }));
337 }
338 }
339 }
340
341 local_heap
342 .into_sorted_vec()
343 .into_iter()
344 .map(|r| r.0)
345 .collect()
346 })
347 .collect();
348
349 let mut final_heap = BinaryHeap::new();
351 for partial in partial_results {
352 for result in partial {
353 if final_heap.len() < k {
354 final_heap.push(std::cmp::Reverse(result));
355 } else if let Some(std::cmp::Reverse(worst)) = final_heap.peek() {
356 if result.distance < worst.distance {
357 final_heap.pop();
358 final_heap.push(std::cmp::Reverse(result));
359 }
360 }
361 }
362 }
363
364 let mut results: Vec<SearchResult> = final_heap.into_iter().map(|r| r.0).collect();
365 results.sort_by(|a, b| {
366 a.distance
367 .partial_cmp(&b.distance)
368 .unwrap_or(std::cmp::Ordering::Equal)
369 });
370
371 Ok(results)
372 }
373
374 pub fn stats(&self) -> IndexStats {
376 IndexStats {
377 num_vectors: self.vectors.len(),
378 dimensions: self.dimensions.unwrap_or(0),
379 index_type: self.config.index_type,
380 memory_usage: self.estimate_memory_usage(),
381 }
382 }
383
384 fn estimate_memory_usage(&self) -> usize {
385 let vector_memory = self.vectors.len()
386 * (std::mem::size_of::<String>()
387 + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());
388
389 let uri_map_memory =
390 self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());
391
392 vector_memory + uri_map_memory
393 }
394
395 pub fn len(&self) -> usize {
397 self.vectors.len()
398 }
399
400 pub fn is_empty(&self) -> bool {
402 self.vectors.is_empty()
403 }
404
405 pub fn add(
407 &mut self,
408 id: String,
409 vector: Vec<f32>,
410 _triple: Triple,
411 _metadata: HashMap<String, String>,
412 ) -> Result<()> {
413 let vector_obj = Vector::new(vector);
414 self.insert(id, vector_obj)
415 }
416
417 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
419 let query_vector = Vector::new(query.to_vec());
420 let results = self.search_advanced(&query_vector, k, None, None)?;
421 Ok(results)
422 }
423}
424
425impl VectorIndex for AdvancedVectorIndex {
426 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
427 if let Some(dims) = self.dimensions {
428 if vector.dimensions != dims {
429 return Err(anyhow!(
430 "Vector dimensions ({}) don't match index dimensions ({})",
431 vector.dimensions,
432 dims
433 ));
434 }
435 } else {
436 self.dimensions = Some(vector.dimensions);
437 }
438
439 let id = self.vectors.len();
440 self.uri_to_id.insert(uri.clone(), id);
441 self.vectors.push((uri, vector));
442
443 Ok(())
444 }
445
446 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
447 let results = self.search_advanced(query, k, None, None)?;
448 Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
449 }
450
451 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
452 let mut results = Vec::new();
453
454 for (uri, vector) in &self.vectors {
455 let distance = self.config.distance_metric.distance_vectors(query, vector);
456 if distance <= threshold {
457 results.push((uri.clone(), distance));
458 }
459 }
460
461 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
462 Ok(results)
463 }
464
465 fn get_vector(&self, uri: &str) -> Option<&Vector> {
466 self.vectors.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
469 }
470}
471
/// Size snapshot of an [`AdvancedVectorIndex`], returned by its `stats()` method.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Number of stored vectors (duplicate URIs counted separately).
    pub num_vectors: usize,
    /// Vector dimensionality, or 0 before the first insert.
    pub dimensions: usize,
    /// Configured index type.
    pub index_type: IndexType,
    /// Rough byte estimate: `String` headers plus raw `f32` payloads only
    /// (no string heap contents, no HNSW graph).
    pub memory_usage: usize,
}
480
/// Index storing each vector as a sequence of `u8` centroid codes, ranked by
/// Hamming distance between code sequences.
pub struct QuantizedVectorIndex {
    // NOTE(review): stored but not read by any method of this type in this file.
    config: IndexConfig,
    /// One code sequence per inserted vector, indexed by id.
    quantized_vectors: Vec<Vec<u8>>,
    /// Codebook produced by `train_quantization`. Before training, this Vec's
    /// capacity records the requested number of centroids.
    centroids: Vec<Vector>,
    /// URI -> index into `quantized_vectors`; for duplicate URIs the latest insert wins.
    uri_to_id: HashMap<String, usize>,
    /// Dimensionality of the training vectors, set by `train_quantization`.
    dimensions: Option<usize>,
}
489
490impl QuantizedVectorIndex {
491 pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
492 Self {
493 config,
494 quantized_vectors: Vec::new(),
495 centroids: Vec::with_capacity(num_centroids),
496 uri_to_id: HashMap::new(),
497 dimensions: None,
498 }
499 }
500
501 pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
503 if training_vectors.is_empty() {
504 return Err(anyhow!("No training vectors provided"));
505 }
506
507 let dimensions = training_vectors[0].dimensions;
508 self.dimensions = Some(dimensions);
509
510 self.centroids = kmeans_clustering(training_vectors, self.centroids.capacity())?;
512
513 Ok(())
514 }
515
516 fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
517 let mut quantized = Vec::new();
518
519 let chunk_size = vector.dimensions / self.centroids.len().max(1);
521
522 let vector_f32 = vector.as_f32();
523 for chunk in vector_f32.chunks(chunk_size) {
524 let mut best_centroid = 0u8;
525 let mut best_distance = f32::INFINITY;
526
527 for (i, centroid) in self.centroids.iter().enumerate() {
528 let centroid_f32 = centroid.as_f32();
529 let centroid_chunk = ¢roid_f32[0..chunk.len().min(centroid.dimensions)];
530 use oxirs_core::simd::SimdOps;
531 let distance = f32::euclidean_distance(chunk, centroid_chunk);
532 if distance < best_distance {
533 best_distance = distance;
534 best_centroid = i as u8;
535 }
536 }
537
538 quantized.push(best_centroid);
539 }
540
541 quantized
542 }
543}
544
545impl VectorIndex for QuantizedVectorIndex {
546 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
547 if self.centroids.is_empty() {
548 return Err(anyhow!(
549 "Quantization not trained. Call train_quantization first."
550 ));
551 }
552
553 let id = self.quantized_vectors.len();
554 self.uri_to_id.insert(uri.clone(), id);
555
556 let quantized = self.quantize_vector(&vector);
557 self.quantized_vectors.push(quantized);
558
559 Ok(())
560 }
561
562 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
563 let query_quantized = self.quantize_vector(query);
564 let mut results = Vec::new();
565
566 for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
567 let distance = hamming_distance(&query_quantized, quantized);
568 results.push((uri.clone(), distance));
569 }
570
571 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
572 results.truncate(k);
573
574 Ok(results)
575 }
576
577 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
578 let query_quantized = self.quantize_vector(query);
579 let mut results = Vec::new();
580
581 for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
582 let distance = hamming_distance(&query_quantized, quantized);
583 if distance <= threshold {
584 results.push((uri.clone(), distance));
585 }
586 }
587
588 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
589 Ok(results)
590 }
591
592 fn get_vector(&self, _uri: &str) -> Option<&Vector> {
593 None
596 }
597}
598
/// Counts the positions at which two code sequences differ, returned as `f32`.
///
/// Only the overlapping prefix is compared; extra trailing codes in the
/// longer slice are ignored.
fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    let mismatches: usize = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| usize::from(x != y))
        .sum();
    mismatches as f32
}
604
605fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
607 if vectors.is_empty() || k == 0 {
608 return Ok(Vec::new());
609 }
610
611 let dimensions = vectors[0].dimensions;
612 let mut centroids = Vec::with_capacity(k);
613
614 for i in 0..k {
616 let idx = i % vectors.len();
617 centroids.push(vectors[idx].clone());
618 }
619
620 for _ in 0..10 {
622 let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];
623
624 for vector in vectors {
626 let mut best_centroid = 0;
627 let mut best_distance = f32::INFINITY;
628
629 for (i, centroid) in centroids.iter().enumerate() {
630 let vector_f32 = vector.as_f32();
631 let centroid_f32 = centroid.as_f32();
632 use oxirs_core::simd::SimdOps;
633 let distance = f32::euclidean_distance(&vector_f32, ¢roid_f32);
634 if distance < best_distance {
635 best_distance = distance;
636 best_centroid = i;
637 }
638 }
639
640 clusters[best_centroid].push(vector);
641 }
642
643 for (i, cluster) in clusters.iter().enumerate() {
645 if !cluster.is_empty() {
646 let mut new_centroid = vec![0.0; dimensions];
647
648 for vector in cluster {
649 let vector_f32 = vector.as_f32();
650 for (j, &value) in vector_f32.iter().enumerate() {
651 new_centroid[j] += value;
652 }
653 }
654
655 for value in &mut new_centroid {
656 *value /= cluster.len() as f32;
657 }
658
659 centroids[i] = Vector::new(new_centroid);
660 }
661 }
662 }
663
664 Ok(centroids)
665}
666
/// Collection of named vector indices that forwards [`VectorIndex`] calls to a default one.
pub struct MultiIndex {
    /// Registered indices by name.
    indices: HashMap<String, Box<dyn VectorIndex>>,
    /// Name of the index that `VectorIndex` methods delegate to; empty until
    /// the first `add_index`.
    default_index: String,
}
672
673impl MultiIndex {
674 pub fn new() -> Self {
675 Self {
676 indices: HashMap::new(),
677 default_index: String::new(),
678 }
679 }
680
681 pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
682 if self.indices.is_empty() {
683 self.default_index = name.clone();
684 }
685 self.indices.insert(name, index);
686 }
687
688 pub fn set_default(&mut self, name: &str) -> Result<()> {
689 if self.indices.contains_key(name) {
690 self.default_index = name.to_string();
691 Ok(())
692 } else {
693 Err(anyhow!("Index '{}' not found", name))
694 }
695 }
696
697 pub fn search_index(
698 &self,
699 index_name: &str,
700 query: &Vector,
701 k: usize,
702 ) -> Result<Vec<(String, f32)>> {
703 if let Some(index) = self.indices.get(index_name) {
704 index.search_knn(query, k)
705 } else {
706 Err(anyhow!("Index '{}' not found", index_name))
707 }
708 }
709}
710
711impl Default for MultiIndex {
712 fn default() -> Self {
713 Self::new()
714 }
715}
716
717impl VectorIndex for MultiIndex {
718 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
719 if let Some(index) = self.indices.get_mut(&self.default_index) {
720 index.insert(uri, vector)
721 } else {
722 Err(anyhow!("No default index set"))
723 }
724 }
725
726 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
727 if let Some(index) = self.indices.get(&self.default_index) {
728 index.search_knn(query, k)
729 } else {
730 Err(anyhow!("No default index set"))
731 }
732 }
733
734 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
735 if let Some(index) = self.indices.get(&self.default_index) {
736 index.search_threshold(query, threshold)
737 } else {
738 Err(anyhow!("No default index set"))
739 }
740 }
741
742 fn get_vector(&self, uri: &str) -> Option<&Vector> {
743 if let Some(index) = self.indices.get(&self.default_index) {
744 index.get_vector(uri)
745 } else {
746 None
747 }
748 }
749}