1use crate::Vector;
4
5pub use crate::VectorIndex;
7use anyhow::{anyhow, Result};
8use oxirs_core::parallel::*;
9use oxirs_core::Triple;
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::sync::Arc;
14
15use crate::hnsw::{HnswConfig, HnswIndex};
16use crate::ivf::{IvfConfig, IvfIndex};
17use crate::pq::{PQConfig, PQIndex};
18
/// Predicate over a result URI; returning `false` excludes that URI from a search.
///
/// Not `Send`/`Sync`, so searches using it run on the calling thread only.
pub type FilterFunction = Box<dyn for<'a> Fn(&'a str) -> bool>;

/// Thread-safe variant of [`FilterFunction`], usable from parallel scans.
pub type FilterFunctionSync = Box<dyn for<'a> Fn(&'a str) -> bool + Send + Sync>;
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
26pub struct IndexConfig {
27 pub index_type: IndexType,
29 pub max_connections: usize,
31 pub ef_construction: usize,
33 pub ef_search: usize,
35 pub distance_metric: DistanceMetric,
37 pub parallel: bool,
39}
40
41impl Default for IndexConfig {
42 fn default() -> Self {
43 Self {
44 index_type: IndexType::Hnsw,
45 max_connections: 16,
46 ef_construction: 200,
47 ef_search: 50,
48 distance_metric: DistanceMetric::Cosine,
49 parallel: true,
50 }
51 }
52}
53
54#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
56pub enum IndexType {
57 Hnsw,
59 Flat,
61 Ivf,
63 PQ,
65}
66
67#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
69pub enum DistanceMetric {
70 Cosine,
72 Euclidean,
74 Manhattan,
76 DotProduct,
78}
79
80impl DistanceMetric {
81 pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
83 use oxirs_core::simd::SimdOps;
84
85 match self {
86 DistanceMetric::Cosine => f32::cosine_distance(a, b),
87 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
88 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
89 DistanceMetric::DotProduct => -f32::dot(a, b), }
91 }
92
93 pub fn distance_vectors(&self, a: &Vector, b: &Vector) -> f32 {
95 let a_f32 = a.as_f32();
96 let b_f32 = b.as_f32();
97 self.distance(&a_f32, &b_f32)
98 }
99}
100
/// A single hit returned by a vector search.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// URI of the matched vector.
    pub uri: String,
    /// Distance between the query and this match (lower is closer).
    pub distance: f32,
    /// Similarity score; the searches in this module set it to `1.0 - distance`.
    pub score: f32,
    /// Optional metadata attached to the hit.
    pub metadata: Option<HashMap<String, String>>,
}

impl Eq for SearchResult {}

impl Ord for SearchResult {
    /// Orders by ascending distance, then by URI.
    ///
    /// `f32::total_cmp` gives a genuine total order, so NaN distances no longer
    /// compare "equal" to everything. That old behavior broke the transitivity
    /// `BinaryHeap` and `sort` depend on. The URI tie-break makes ordering among
    /// equal distances deterministic.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.uri.cmp(&other.uri))
    }
}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
125
126pub struct AdvancedVectorIndex {
128 config: IndexConfig,
129 vectors: Vec<(String, Vector)>,
130 uri_to_id: HashMap<String, usize>,
131 hnsw_index: Option<HnswIndex>,
132 ivf_index: Option<IvfIndex>,
134 pq_index: Option<PQIndex>,
136 dimensions: Option<usize>,
137}
138
139impl AdvancedVectorIndex {
140 pub fn new(config: IndexConfig) -> Self {
141 Self {
142 config,
143 vectors: Vec::new(),
144 uri_to_id: HashMap::new(),
145 hnsw_index: None,
146 ivf_index: None,
147 pq_index: None,
148 dimensions: None,
149 }
150 }
151
152 pub fn build(&mut self) -> Result<()> {
154 if self.vectors.is_empty() {
155 return Ok(());
156 }
157
158 match self.config.index_type {
159 IndexType::Hnsw => {
160 self.build_hnsw_index()?;
161 }
162 IndexType::Flat => {
163 }
165 IndexType::Ivf => {
166 self.build_ivf_index()?;
167 }
168 IndexType::PQ => {
169 self.build_pq_index()?;
170 }
171 }
172
173 Ok(())
174 }
175
176 fn build_hnsw_index(&mut self) -> Result<()> {
177 if self.dimensions.is_some() {
178 let hnsw_config = HnswConfig {
179 m: self.config.max_connections,
180 m_l0: self.config.max_connections * 2,
181 ef_construction: self.config.ef_construction,
182 ef: self.config.ef_search,
183 ..HnswConfig::default()
184 };
185
186 let mut hnsw = HnswIndex::new_cpu_only(hnsw_config);
187
188 for (uri, vector) in &self.vectors {
189 hnsw.insert(uri.clone(), vector.clone())?;
190 }
191
192 self.hnsw_index = Some(hnsw);
193 }
194
195 Ok(())
196 }
197
198 fn build_ivf_index(&mut self) -> Result<()> {
204 let training_vectors: Vec<Vector> = self.vectors.iter().map(|(_, v)| v.clone()).collect();
205 let n_clusters = (self.vectors.len() / 4).clamp(2, 256);
206
207 let config = IvfConfig {
208 n_clusters,
209 n_probes: (n_clusters / 8).max(1),
210 ..Default::default()
211 };
212 let mut ivf = IvfIndex::new(config)?;
213 ivf.train(&training_vectors)?;
214
215 for (uri, vector) in &self.vectors {
216 ivf.insert(uri.clone(), vector.clone())?;
217 }
218
219 self.ivf_index = Some(ivf);
220 Ok(())
221 }
222
223 fn build_pq_index(&mut self) -> Result<()> {
229 let dims = self
230 .dimensions
231 .ok_or_else(|| anyhow!("Cannot build PQ index: no vectors have been inserted yet"))?;
232
233 let n_subquantizers = [8usize, 4, 2, 1]
235 .iter()
236 .copied()
237 .find(|&s| dims % s == 0)
238 .unwrap_or(1);
239
240 let config = PQConfig {
241 n_subquantizers,
242 n_centroids: 16, ..Default::default()
244 };
245 let mut pq = PQIndex::new(config);
246 let training_vectors: Vec<Vector> = self.vectors.iter().map(|(_, v)| v.clone()).collect();
247 pq.train(&training_vectors)?;
248
249 for (uri, vector) in &self.vectors {
250 pq.insert(uri.clone(), vector.clone())?;
251 }
252
253 self.pq_index = Some(pq);
254 Ok(())
255 }
256
257 pub fn add_metadata(&mut self, _uri: &str, _metadata: HashMap<String, String>) -> Result<()> {
259 Ok(())
262 }
263
264 pub fn search_advanced(
266 &self,
267 query: &Vector,
268 k: usize,
269 _ef: Option<usize>,
270 filter: Option<FilterFunction>,
271 ) -> Result<Vec<SearchResult>> {
272 match self.config.index_type {
273 IndexType::Hnsw => self.search_hnsw(query, k),
274 IndexType::Ivf => self.search_ivf(query, k),
275 IndexType::PQ => self.search_pq(query, k),
276 IndexType::Flat => self.search_flat(query, k, filter),
277 }
278 }
279
280 fn search_hnsw(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
281 if let Some(ref hnsw) = self.hnsw_index {
282 let results = hnsw.search_knn(query, k)?;
283
284 Ok(results
285 .into_iter()
286 .map(|(uri, distance)| SearchResult {
287 uri,
288 distance,
289 score: 1.0 - distance,
290 metadata: None,
291 })
292 .collect())
293 } else {
294 Err(anyhow!("HNSW index not built"))
295 }
296 }
297
298 fn search_ivf(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
299 let ivf = self
300 .ivf_index
301 .as_ref()
302 .ok_or_else(|| anyhow!("IVF index not built — call build() first"))?;
303 let results = ivf.search_knn(query, k)?;
304 Ok(results
305 .into_iter()
306 .map(|(uri, distance)| SearchResult {
307 uri,
308 score: 1.0 - distance,
309 distance,
310 metadata: None,
311 })
312 .collect())
313 }
314
315 fn search_pq(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
316 let pq = self
317 .pq_index
318 .as_ref()
319 .ok_or_else(|| anyhow!("PQ index not built — call build() first"))?;
320 let results = pq.search_knn(query, k)?;
321 Ok(results
322 .into_iter()
323 .map(|(uri, distance)| SearchResult {
324 uri,
325 score: 1.0 - distance,
326 distance,
327 metadata: None,
328 })
329 .collect())
330 }
331
332 fn search_flat(
333 &self,
334 query: &Vector,
335 k: usize,
336 filter: Option<FilterFunction>,
337 ) -> Result<Vec<SearchResult>> {
338 if self.config.parallel && self.vectors.len() > 1000 {
339 if filter.is_some() {
341 self.search_flat_sequential(query, k, filter)
343 } else {
344 self.search_flat_parallel(query, k, None)
345 }
346 } else {
347 self.search_flat_sequential(query, k, filter)
348 }
349 }
350
351 fn search_flat_sequential(
352 &self,
353 query: &Vector,
354 k: usize,
355 filter: Option<FilterFunction>,
356 ) -> Result<Vec<SearchResult>> {
357 let mut heap = BinaryHeap::new();
358
359 for (uri, vector) in &self.vectors {
360 if let Some(ref filter_fn) = filter {
361 if !filter_fn(uri) {
362 continue;
363 }
364 }
365
366 let distance = self.config.distance_metric.distance_vectors(query, vector);
367
368 if heap.len() < k {
369 heap.push(std::cmp::Reverse(SearchResult {
370 uri: uri.clone(),
371 distance,
372 score: 1.0 - distance, metadata: None,
374 }));
375 } else if let Some(std::cmp::Reverse(worst)) = heap.peek() {
376 if distance < worst.distance {
377 heap.pop();
378 heap.push(std::cmp::Reverse(SearchResult {
379 uri: uri.clone(),
380 distance,
381 score: 1.0 - distance, metadata: None,
383 }));
384 }
385 }
386 }
387
388 let mut results: Vec<SearchResult> = heap.into_iter().map(|r| r.0).collect();
389 results.sort_by(|a, b| {
390 a.distance
391 .partial_cmp(&b.distance)
392 .unwrap_or(std::cmp::Ordering::Equal)
393 });
394
395 Ok(results)
396 }
397
398 fn search_flat_parallel(
399 &self,
400 query: &Vector,
401 k: usize,
402 filter: Option<FilterFunctionSync>,
403 ) -> Result<Vec<SearchResult>> {
404 let chunk_size = (self.vectors.len() / num_threads()).max(100);
406
407 let filter_arc = filter.map(Arc::new);
409
410 let partial_results: Vec<Vec<SearchResult>> = self
412 .vectors
413 .par_chunks(chunk_size)
414 .map(|chunk| {
415 let mut local_heap = BinaryHeap::new();
416 let filter_ref = filter_arc.as_ref();
417
418 for (uri, vector) in chunk {
419 if let Some(filter_fn) = filter_ref {
420 if !filter_fn(uri) {
421 continue;
422 }
423 }
424
425 let distance = self.config.distance_metric.distance_vectors(query, vector);
426
427 if local_heap.len() < k {
428 local_heap.push(std::cmp::Reverse(SearchResult {
429 uri: uri.clone(),
430 distance,
431 score: 1.0 - distance, metadata: None,
433 }));
434 } else if let Some(std::cmp::Reverse(worst)) = local_heap.peek() {
435 if distance < worst.distance {
436 local_heap.pop();
437 local_heap.push(std::cmp::Reverse(SearchResult {
438 uri: uri.clone(),
439 distance,
440 score: 1.0 - distance, metadata: None,
442 }));
443 }
444 }
445 }
446
447 local_heap
448 .into_sorted_vec()
449 .into_iter()
450 .map(|r| r.0)
451 .collect()
452 })
453 .collect();
454
455 let mut final_heap = BinaryHeap::new();
457 for partial in partial_results {
458 for result in partial {
459 if final_heap.len() < k {
460 final_heap.push(std::cmp::Reverse(result));
461 } else if let Some(std::cmp::Reverse(worst)) = final_heap.peek() {
462 if result.distance < worst.distance {
463 final_heap.pop();
464 final_heap.push(std::cmp::Reverse(result));
465 }
466 }
467 }
468 }
469
470 let mut results: Vec<SearchResult> = final_heap.into_iter().map(|r| r.0).collect();
471 results.sort_by(|a, b| {
472 a.distance
473 .partial_cmp(&b.distance)
474 .unwrap_or(std::cmp::Ordering::Equal)
475 });
476
477 Ok(results)
478 }
479
480 pub fn stats(&self) -> IndexStats {
482 IndexStats {
483 num_vectors: self.vectors.len(),
484 dimensions: self.dimensions.unwrap_or(0),
485 index_type: self.config.index_type,
486 memory_usage: self.estimate_memory_usage(),
487 }
488 }
489
490 fn estimate_memory_usage(&self) -> usize {
491 let vector_memory = self.vectors.len()
492 * (std::mem::size_of::<String>()
493 + self.dimensions.unwrap_or(0) * std::mem::size_of::<f32>());
494
495 let uri_map_memory =
496 self.uri_to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<usize>());
497
498 vector_memory + uri_map_memory
499 }
500
501 pub fn len(&self) -> usize {
503 self.vectors.len()
504 }
505
506 pub fn is_empty(&self) -> bool {
508 self.vectors.is_empty()
509 }
510
511 pub fn add(
513 &mut self,
514 id: String,
515 vector: Vec<f32>,
516 _triple: Triple,
517 _metadata: HashMap<String, String>,
518 ) -> Result<()> {
519 let vector_obj = Vector::new(vector);
520 self.insert(id, vector_obj)
521 }
522
523 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
525 let query_vector = Vector::new(query.to_vec());
526 let results = self.search_advanced(&query_vector, k, None, None)?;
527 Ok(results)
528 }
529}
530
531impl VectorIndex for AdvancedVectorIndex {
532 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
533 if let Some(dims) = self.dimensions {
534 if vector.dimensions != dims {
535 return Err(anyhow!(
536 "Vector dimensions ({}) don't match index dimensions ({})",
537 vector.dimensions,
538 dims
539 ));
540 }
541 } else {
542 self.dimensions = Some(vector.dimensions);
543 }
544
545 let id = self.vectors.len();
546 self.uri_to_id.insert(uri.clone(), id);
547 self.vectors.push((uri, vector));
548
549 Ok(())
550 }
551
552 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
553 let results = self.search_advanced(query, k, None, None)?;
554 Ok(results.into_iter().map(|r| (r.uri, r.distance)).collect())
555 }
556
557 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
558 let mut results = Vec::new();
559
560 for (uri, vector) in &self.vectors {
561 let distance = self.config.distance_metric.distance_vectors(query, vector);
562 if distance <= threshold {
563 results.push((uri.clone(), distance));
564 }
565 }
566
567 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
568 Ok(results)
569 }
570
571 fn get_vector(&self, uri: &str) -> Option<&Vector> {
572 self.vectors.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
575 }
576}
577
578#[derive(Debug, Clone)]
580pub struct IndexStats {
581 pub num_vectors: usize,
582 pub dimensions: usize,
583 pub index_type: IndexType,
584 pub memory_usage: usize,
585}
586
587pub struct QuantizedVectorIndex {
589 config: IndexConfig,
590 quantized_vectors: Vec<Vec<u8>>,
591 centroids: Vec<Vector>,
592 uri_to_id: HashMap<String, usize>,
593 dimensions: Option<usize>,
594}
595
596impl QuantizedVectorIndex {
597 pub fn new(config: IndexConfig, num_centroids: usize) -> Self {
598 Self {
599 config,
600 quantized_vectors: Vec::new(),
601 centroids: Vec::with_capacity(num_centroids),
602 uri_to_id: HashMap::new(),
603 dimensions: None,
604 }
605 }
606
607 pub fn train_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
609 if training_vectors.is_empty() {
610 return Err(anyhow!("No training vectors provided"));
611 }
612
613 let dimensions = training_vectors[0].dimensions;
614 self.dimensions = Some(dimensions);
615
616 self.centroids = kmeans_clustering(training_vectors, self.centroids.capacity())?;
618
619 Ok(())
620 }
621
622 fn quantize_vector(&self, vector: &Vector) -> Vec<u8> {
623 let mut quantized = Vec::new();
624
625 let chunk_size = vector.dimensions / self.centroids.len().max(1);
627
628 let vector_f32 = vector.as_f32();
629 for chunk in vector_f32.chunks(chunk_size) {
630 let mut best_centroid = 0u8;
631 let mut best_distance = f32::INFINITY;
632
633 for (i, centroid) in self.centroids.iter().enumerate() {
634 let centroid_f32 = centroid.as_f32();
635 let centroid_chunk = ¢roid_f32[0..chunk.len().min(centroid.dimensions)];
636 use oxirs_core::simd::SimdOps;
637 let distance = f32::euclidean_distance(chunk, centroid_chunk);
638 if distance < best_distance {
639 best_distance = distance;
640 best_centroid = i as u8;
641 }
642 }
643
644 quantized.push(best_centroid);
645 }
646
647 quantized
648 }
649}
650
651impl VectorIndex for QuantizedVectorIndex {
652 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
653 if self.centroids.is_empty() {
654 return Err(anyhow!(
655 "Quantization not trained. Call train_quantization first."
656 ));
657 }
658
659 let id = self.quantized_vectors.len();
660 self.uri_to_id.insert(uri.clone(), id);
661
662 let quantized = self.quantize_vector(&vector);
663 self.quantized_vectors.push(quantized);
664
665 Ok(())
666 }
667
668 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
669 let query_quantized = self.quantize_vector(query);
670 let mut results = Vec::new();
671
672 for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
673 let distance = hamming_distance(&query_quantized, quantized);
674 results.push((uri.clone(), distance));
675 }
676
677 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
678 results.truncate(k);
679
680 Ok(results)
681 }
682
683 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
684 let query_quantized = self.quantize_vector(query);
685 let mut results = Vec::new();
686
687 for (uri, quantized) in self.uri_to_id.keys().zip(&self.quantized_vectors) {
688 let distance = hamming_distance(&query_quantized, quantized);
689 if distance <= threshold {
690 results.push((uri.clone(), distance));
691 }
692 }
693
694 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
695 Ok(results)
696 }
697
698 fn get_vector(&self, _uri: &str) -> Option<&Vector> {
699 None
702 }
703}
704
/// Counts the positions at which two code sequences differ, as an `f32`.
///
/// Only the overlapping prefix is compared (`zip` stops at the shorter input),
/// so trailing codes of a longer sequence are ignored.
fn hamming_distance(a: &[u8], b: &[u8]) -> f32 {
    let mismatches = a
        .iter()
        .zip(b.iter())
        .fold(0usize, |count, (lhs, rhs)| count + usize::from(lhs != rhs));
    mismatches as f32
}
710
711fn kmeans_clustering(vectors: &[Vector], k: usize) -> Result<Vec<Vector>> {
713 if vectors.is_empty() || k == 0 {
714 return Ok(Vec::new());
715 }
716
717 let dimensions = vectors[0].dimensions;
718 let mut centroids = Vec::with_capacity(k);
719
720 for i in 0..k {
722 let idx = i % vectors.len();
723 centroids.push(vectors[idx].clone());
724 }
725
726 for _ in 0..10 {
728 let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); k];
729
730 for vector in vectors {
732 let mut best_centroid = 0;
733 let mut best_distance = f32::INFINITY;
734
735 for (i, centroid) in centroids.iter().enumerate() {
736 let vector_f32 = vector.as_f32();
737 let centroid_f32 = centroid.as_f32();
738 use oxirs_core::simd::SimdOps;
739 let distance = f32::euclidean_distance(&vector_f32, ¢roid_f32);
740 if distance < best_distance {
741 best_distance = distance;
742 best_centroid = i;
743 }
744 }
745
746 clusters[best_centroid].push(vector);
747 }
748
749 for (i, cluster) in clusters.iter().enumerate() {
751 if !cluster.is_empty() {
752 let mut new_centroid = vec![0.0; dimensions];
753
754 for vector in cluster {
755 let vector_f32 = vector.as_f32();
756 for (j, &value) in vector_f32.iter().enumerate() {
757 new_centroid[j] += value;
758 }
759 }
760
761 for value in &mut new_centroid {
762 *value /= cluster.len() as f32;
763 }
764
765 centroids[i] = Vector::new(new_centroid);
766 }
767 }
768 }
769
770 Ok(centroids)
771}
772
773pub struct MultiIndex {
775 indices: HashMap<String, Box<dyn VectorIndex>>,
776 default_index: String,
777}
778
779impl MultiIndex {
780 pub fn new() -> Self {
781 Self {
782 indices: HashMap::new(),
783 default_index: String::new(),
784 }
785 }
786
787 pub fn add_index(&mut self, name: String, index: Box<dyn VectorIndex>) {
788 if self.indices.is_empty() {
789 self.default_index = name.clone();
790 }
791 self.indices.insert(name, index);
792 }
793
794 pub fn set_default(&mut self, name: &str) -> Result<()> {
795 if self.indices.contains_key(name) {
796 self.default_index = name.to_string();
797 Ok(())
798 } else {
799 Err(anyhow!("Index '{}' not found", name))
800 }
801 }
802
803 pub fn search_index(
804 &self,
805 index_name: &str,
806 query: &Vector,
807 k: usize,
808 ) -> Result<Vec<(String, f32)>> {
809 if let Some(index) = self.indices.get(index_name) {
810 index.search_knn(query, k)
811 } else {
812 Err(anyhow!("Index '{}' not found", index_name))
813 }
814 }
815}
816
817impl Default for MultiIndex {
818 fn default() -> Self {
819 Self::new()
820 }
821}
822
823impl VectorIndex for MultiIndex {
824 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
825 if let Some(index) = self.indices.get_mut(&self.default_index) {
826 index.insert(uri, vector)
827 } else {
828 Err(anyhow!("No default index set"))
829 }
830 }
831
832 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
833 if let Some(index) = self.indices.get(&self.default_index) {
834 index.search_knn(query, k)
835 } else {
836 Err(anyhow!("No default index set"))
837 }
838 }
839
840 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
841 if let Some(index) = self.indices.get(&self.default_index) {
842 index.search_threshold(query, threshold)
843 } else {
844 Err(anyhow!("No default index set"))
845 }
846 }
847
848 fn get_vector(&self, uri: &str) -> Option<&Vector> {
849 if let Some(index) = self.indices.get(&self.default_index) {
850 index.get_vector(uri)
851 } else {
852 None
853 }
854 }
855}
856
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight 4-dimensional fixtures keyed by example URIs.
    fn sample_vectors() -> Vec<(&'static str, Vector)> {
        const DATA: [(&str, [f32; 4]); 8] = [
            ("http://example.org/a", [1.0, 0.0, 0.0, 1.0]),
            ("http://example.org/b", [0.0, 1.0, 1.0, 0.0]),
            ("http://example.org/c", [-1.0, 0.0, 0.0, -1.0]),
            ("http://example.org/d", [0.0, -1.0, -1.0, 0.0]),
            ("http://example.org/e", [0.5, 0.5, 0.5, 0.5]),
            ("http://example.org/f", [-0.5, 0.5, -0.5, 0.5]),
            ("http://example.org/g", [1.0, 1.0, 0.0, 0.0]),
            ("http://example.org/h", [0.0, 0.0, 1.0, 1.0]),
        ];

        DATA.iter()
            .map(|&(uri, values)| (uri, Vector::new(values.to_vec())))
            .collect()
    }

    /// Builds an index of `index_type` holding all fixtures.
    fn build_index(index_type: IndexType) -> Result<AdvancedVectorIndex> {
        let mut idx = AdvancedVectorIndex::new(IndexConfig {
            index_type,
            ..IndexConfig::default()
        });
        for (uri, vector) in sample_vectors() {
            idx.insert(uri.to_string(), vector)?;
        }
        idx.build()?;
        Ok(idx)
    }

    #[test]
    fn test_ivf_build_and_search() -> Result<()> {
        let idx = build_index(IndexType::Ivf)?;
        assert!(idx.ivf_index.is_some(), "IVF index should be built");

        let query = Vector::new(vec![1.0, 0.0, 0.0, 1.0]);
        let hits = idx.search(&query.as_f32(), 3)?;
        assert!(!hits.is_empty(), "IVF search should return results");
        Ok(())
    }

    #[test]
    fn test_pq_build_and_search() -> Result<()> {
        let idx = build_index(IndexType::PQ)?;
        assert!(idx.pq_index.is_some(), "PQ index should be built");

        let query = Vector::new(vec![0.0, 1.0, 1.0, 0.0]);
        let hits = idx.search(&query.as_f32(), 3)?;
        assert!(!hits.is_empty(), "PQ search should return results");
        Ok(())
    }

    #[test]
    fn test_flat_search_unchanged() -> Result<()> {
        let idx = build_index(IndexType::Flat)?;
        let query = Vector::new(vec![1.0, 0.0, 0.0, 1.0]);
        let hits = idx.search(&query.as_f32(), 2)?;
        assert_eq!(
            hits.len(),
            2,
            "Flat search should return exactly k results"
        );
        Ok(())
    }
}