1use crate::index::VectorIndex;
8use anyhow::{Context, Error as AnyhowError, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12use std::sync::{Arc, Mutex, RwLock};
13use tracing::{debug, info, span, Level};
14
/// Configuration for a [`FaissIndex`].
///
/// The configuration is checked by `FaissIndex::validate_config` when an index
/// is constructed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissConfig {
    /// Kind of index to build; selects the factory string and the distance function.
    pub index_type: FaissIndexType,
    /// Number of components every stored, training and query vector must have.
    /// Must be non-zero.
    pub dimension: usize,
    /// Must be non-zero. Not otherwise read in this file — presumably the number of
    /// vectors sampled for training; confirm against callers.
    pub training_sample_size: usize,
    /// Cluster count used in the `IVF<n>` part of the factory string.
    /// Required for the IVF index types.
    pub num_clusters: Option<usize>,
    /// `m` in the `PQ<m>x<bits>` part of the factory string. Required for `IvfPq`.
    pub num_subquantizers: Option<usize>,
    /// `bits` in the `PQ<m>x<bits>` part of the factory string. Required for `IvfPq`.
    pub bits_per_subquantizer: Option<u8>,
    /// When `true`, the index handle records a GPU device id.
    pub use_gpu: bool,
    /// GPU device ids; only the first entry is read (falling back to `0` when empty).
    pub gpu_devices: Vec<u32>,
    /// NOTE(review): presumably toggles memory-mapped index files; it is not read
    /// anywhere in this file — confirm.
    pub enable_mmap: bool,
    /// Persistence settings (not read in this file).
    pub persistence: FaissPersistenceConfig,
    /// Optimization settings (not read in this file).
    pub optimization: FaissOptimizationConfig,
}
41
42impl Default for FaissConfig {
43 fn default() -> Self {
44 Self {
45 index_type: FaissIndexType::FlatL2,
46 dimension: 384,
47 training_sample_size: 10000,
48 num_clusters: Some(1024),
49 num_subquantizers: Some(8),
50 bits_per_subquantizer: Some(8),
51 use_gpu: false,
52 gpu_devices: vec![0],
53 enable_mmap: true,
54 persistence: FaissPersistenceConfig::default(),
55 optimization: FaissOptimizationConfig::default(),
56 }
57 }
58}
59
/// Index families understood by [`FaissIndex`].
///
/// Each variant maps to a factory string in `FaissIndex::faiss_index_string`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FaissIndexType {
    /// Factory string `"Flat"`; searched with Euclidean distance. Needs no training.
    FlatL2,
    /// Factory string `"Flat"`; searched with the negated dot product. Needs no training.
    FlatIP,
    /// Factory string `"IVF<clusters>,Flat"`.
    IvfFlat,
    /// Factory string `"IVF<clusters>,PQ<subquantizers>x<bits>"`.
    IvfPq,
    /// Factory string `"IVF<clusters>,SQ8"`.
    IvfSq,
    /// Factory string `"HNSW32,Flat"`.
    HnswFlat,
    /// Factory string `"LSH"`.
    Lsh,
    /// Factory string chosen from the current vector count and the dimension.
    Auto,
    /// Factory string supplied verbatim by the caller.
    Custom(String),
}
82
/// Settings describing where and how an index is persisted.
///
/// None of these fields is read in this file; the notes below are assumptions
/// to be confirmed against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissPersistenceConfig {
    /// Directory for index files (default `./faiss_indices`).
    pub index_directory: PathBuf,
    /// Presumably enables periodic automatic saving — confirm.
    pub auto_save: bool,
    /// Interval between automatic saves (default `300`); unit not shown here,
    /// presumably seconds — confirm.
    pub save_interval: u64,
    /// Presumably enables compression of saved index files — confirm.
    pub compression: bool,
    /// Backup settings.
    pub backup: FaissBackupConfig,
}
97
98impl Default for FaissPersistenceConfig {
99 fn default() -> Self {
100 Self {
101 index_directory: PathBuf::from("./faiss_indices"),
102 auto_save: true,
103 save_interval: 300, compression: true,
105 backup: FaissBackupConfig::default(),
106 }
107 }
108}
109
/// Backup settings for persisted indices.
///
/// None of these fields is read in this file; the notes below are assumptions
/// to be confirmed against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissBackupConfig {
    /// Whether backups are taken at all.
    pub enabled: bool,
    /// Directory for backup files (default `./faiss_backups`).
    pub backup_directory: PathBuf,
    /// Presumably the number of backup versions retained (default `5`) — confirm.
    pub max_versions: usize,
    /// Backup cadence (default `3600`); unit not shown here, presumably seconds — confirm.
    pub backup_frequency: u64,
}
122
123impl Default for FaissBackupConfig {
124 fn default() -> Self {
125 Self {
126 enabled: true,
127 backup_directory: PathBuf::from("./faiss_backups"),
128 max_versions: 5,
129 backup_frequency: 3600, }
131 }
132}
133
/// Optimization settings for an index.
///
/// None of these fields is read in this file; the notes below are assumptions
/// to be confirmed against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissOptimizationConfig {
    /// Presumably enables automatic optimization passes — confirm.
    pub auto_optimize: bool,
    /// Optimization cadence (default `100000`); presumably counted in
    /// operations or inserted vectors — confirm.
    pub optimization_frequency: usize,
    /// Presumably enables runtime tuning of search parameters — confirm.
    pub dynamic_tuning: bool,
    /// Monitoring settings.
    pub monitoring: FaissMonitoringConfig,
}
146
147impl Default for FaissOptimizationConfig {
148 fn default() -> Self {
149 Self {
150 auto_optimize: true,
151 optimization_frequency: 100000,
152 dynamic_tuning: true,
153 monitoring: FaissMonitoringConfig::default(),
154 }
155 }
156}
157
/// Monitoring settings for an index.
///
/// None of these fields is read in this file; the notes below are assumptions
/// to be confirmed against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissMonitoringConfig {
    /// Whether monitoring is active.
    pub enabled: bool,
    /// Collection cadence (default `60`); unit not shown here, presumably seconds — confirm.
    pub collection_interval: u64,
    /// Presumably enables memory-usage tracking — confirm.
    pub track_memory: bool,
    /// Presumably enables per-query tracking — confirm.
    pub track_queries: bool,
}
170
171impl Default for FaissMonitoringConfig {
172 fn default() -> Self {
173 Self {
174 enabled: true,
175 collection_interval: 60,
176 track_memory: true,
177 track_queries: true,
178 }
179 }
180}
181
/// Per-query search parameters.
///
/// Only `k` is read by the search path in this file; the remaining fields are
/// carried along but not consulted here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissSearchParams {
    /// Maximum number of results to return.
    pub k: usize,
    /// Not read in this file — presumably the number of IVF clusters to probe; confirm.
    pub nprobe: Option<usize>,
    /// Not read in this file — presumably the HNSW `ef` search breadth; confirm.
    pub hnsw_ef: Option<usize>,
    /// Not read in this file — presumably forces an exhaustive search; confirm.
    pub exact_search: bool,
    /// Not read in this file — presumably a per-query timeout in milliseconds; confirm.
    pub timeout_ms: Option<u64>,
}
196
197impl Default for FaissSearchParams {
198 fn default() -> Self {
199 Self {
200 k: 10,
201 nprobe: Some(64),
202 hnsw_ef: Some(128),
203 exact_search: false,
204 timeout_ms: Some(5000),
205 }
206 }
207}
208
/// In-process stand-in for a FAISS index.
///
/// Vectors are kept in a plain `Vec` and searched by brute force; every piece
/// of mutable state sits behind its own lock so the methods can take `&self`.
pub struct FaissIndex {
    /// Configuration the index was created with.
    config: FaissConfig,
    /// Descriptor of the index; `None` until `initialize_faiss_index` has run.
    index_handle: Arc<Mutex<Option<FaissIndexHandle>>>,
    /// Stored vectors; a vector's position is its internal index.
    vectors: Arc<RwLock<Vec<Vec<f32>>>>,
    /// Metadata keyed by the vector's position in `vectors`.
    metadata: Arc<RwLock<HashMap<usize, VectorMetadata>>>,
    /// Running counters and timings.
    stats: Arc<RwLock<FaissStatistics>>,
    /// Progress and outcome of training.
    training_state: Arc<RwLock<TrainingState>>,
}
224
/// Descriptor of the index held inside [`FaissIndex`].
#[derive(Debug)]
pub struct FaissIndexHandle {
    /// Factory string the index was created from (e.g. `"IVF1024,Flat"`).
    pub index_type: String,
    /// Number of vectors added so far; incremented by `add_vectors`.
    pub num_vectors: usize,
    /// Vector dimension copied from the configuration.
    pub dimension: usize,
    /// Training flag; set to `true` by `train`.
    pub is_trained: bool,
    /// GPU device id when the configuration requests GPU use, otherwise `None`.
    pub gpu_device: Option<u32>,
}
239
/// Metadata recorded for each stored vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorMetadata {
    /// Caller-supplied identifier; returned in search results.
    pub id: String,
    /// Wall-clock time at which the vector was added.
    pub timestamp: std::time::SystemTime,
    /// Euclidean (L2) norm of the vector, computed at insertion.
    pub norm: f32,
    /// Free-form attributes; empty when created by `add_vectors`.
    pub attributes: HashMap<String, String>,
}
252
/// Progress and outcome of index training.
///
/// `Default` yields the untrained state: `is_trained == false`,
/// `training_progress == 0.0`, no start time and a zero vector count. It is
/// derived because those are exactly the field types' own defaults.
#[derive(Debug, Clone, Default)]
pub struct TrainingState {
    /// `true` once `FaissIndex::train` has completed.
    pub is_trained: bool,
    /// Fraction of the training run completed, from `0.0` to `1.0`.
    pub training_progress: f32,
    /// When the most recent training run started, if any.
    pub training_start: Option<std::time::Instant>,
    /// Number of vectors passed to the most recent training run.
    pub training_vectors_count: usize,
}
276
/// Counters and timings collected by [`FaissIndex`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FaissStatistics {
    /// Number of vectors added through `add_vectors`.
    pub total_vectors: usize,
    /// Number of completed `search` calls.
    pub total_searches: usize,
    /// Running mean of the search time, in microseconds.
    pub avg_search_time_us: f64,
    /// Not updated in this file — presumably the host memory footprint in bytes; confirm.
    pub memory_usage_bytes: usize,
    /// Not updated in this file — presumably the GPU memory footprint in bytes; confirm.
    pub gpu_memory_usage_bytes: Option<usize>,
    /// Cumulative time spent inside `add_vectors`, in seconds.
    pub index_build_time_s: f64,
    /// Wall-clock time of the most recent `optimize` call, if any.
    pub last_optimization: Option<std::time::SystemTime>,
    /// Not written in this file — presumably periodic performance snapshots; confirm.
    pub performance_history: Vec<PerformanceSnapshot>,
}
297
/// One point-in-time performance sample.
///
/// No snapshot is constructed in this file, so the units below are inferred
/// from the field names only — confirm against the code that produces them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceSnapshot {
    /// Wall-clock time the sample was taken.
    pub timestamp: std::time::SystemTime,
    /// Median search latency (unit not shown here).
    pub search_latency_p50: f64,
    /// 95th-percentile search latency (unit not shown here).
    pub search_latency_p95: f64,
    /// 99th-percentile search latency (unit not shown here).
    pub search_latency_p99: f64,
    /// Throughput, presumably in queries per second.
    pub throughput_qps: f64,
    /// Memory usage, presumably in megabytes.
    pub memory_usage_mb: f64,
    /// GPU utilization, if available (scale not shown here).
    pub gpu_utilization: Option<f32>,
}
314
315impl FaissIndex {
316 pub fn new(config: FaissConfig) -> Result<Self> {
318 let span = span!(Level::INFO, "faiss_index_new");
319 let _enter = span.enter();
320
321 Self::validate_config(&config)?;
323
324 let index = Self {
325 config: config.clone(),
326 index_handle: Arc::new(Mutex::new(None)),
327 vectors: Arc::new(RwLock::new(Vec::new())),
328 metadata: Arc::new(RwLock::new(HashMap::new())),
329 stats: Arc::new(RwLock::new(FaissStatistics::default())),
330 training_state: Arc::new(RwLock::new(TrainingState::default())),
331 };
332
333 index.initialize_faiss_index()?;
335
336 info!(
337 "Created FAISS index with type {:?}, dimension {}",
338 config.index_type, config.dimension
339 );
340
341 Ok(index)
342 }
343
344 fn validate_config(config: &FaissConfig) -> Result<()> {
346 if config.dimension == 0 {
347 return Err(AnyhowError::msg("Dimension must be greater than 0"));
348 }
349
350 if config.training_sample_size == 0 {
351 return Err(AnyhowError::msg(
352 "Training sample size must be greater than 0",
353 ));
354 }
355
356 match &config.index_type {
358 FaissIndexType::IvfFlat | FaissIndexType::IvfSq => {
359 if config.num_clusters.is_none() {
360 return Err(AnyhowError::msg(
361 "IVF indices require num_clusters to be set",
362 ));
363 }
364 }
365 FaissIndexType::IvfPq => {
366 if config.num_clusters.is_none() {
367 return Err(AnyhowError::msg(
368 "IVF indices require num_clusters to be set",
369 ));
370 }
371 if config.num_subquantizers.is_none() {
372 return Err(AnyhowError::msg(
373 "IVF-PQ requires num_subquantizers to be set",
374 ));
375 }
376 if config.bits_per_subquantizer.is_none() {
377 return Err(AnyhowError::msg(
378 "IVF-PQ requires bits_per_subquantizer to be set",
379 ));
380 }
381 }
382 _ => {}
383 }
384
385 Ok(())
386 }
387
388 fn initialize_faiss_index(&self) -> Result<()> {
390 let span = span!(Level::DEBUG, "initialize_faiss_index");
391 let _enter = span.enter();
392
393 let index_type_str = self.faiss_index_string()?;
395
396 let handle = FaissIndexHandle {
397 index_type: index_type_str,
398 num_vectors: 0,
399 dimension: self.config.dimension,
400 is_trained: self.requires_training(),
401 gpu_device: if self.config.use_gpu {
402 Some(self.config.gpu_devices.first().copied().unwrap_or(0))
403 } else {
404 None
405 },
406 };
407
408 let mut index_handle = self
409 .index_handle
410 .lock()
411 .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
412 *index_handle = Some(handle);
413
414 debug!("Initialized FAISS index: {}", self.faiss_index_string()?);
415 Ok(())
416 }
417
418 fn requires_training(&self) -> bool {
420 !matches!(
421 self.config.index_type,
422 FaissIndexType::FlatL2 | FaissIndexType::FlatIP
423 )
424 }
425
426 fn faiss_index_string(&self) -> Result<String> {
428 let index_str = match &self.config.index_type {
429 FaissIndexType::FlatL2 => "Flat".to_string(),
430 FaissIndexType::FlatIP => "Flat".to_string(),
431 FaissIndexType::IvfFlat => {
432 let clusters = self.config.num_clusters.unwrap_or(1024);
433 format!("IVF{clusters},Flat")
434 }
435 FaissIndexType::IvfPq => {
436 let clusters = self.config.num_clusters.unwrap_or(1024);
437 let subq = self.config.num_subquantizers.unwrap_or(8);
438 let bits = self.config.bits_per_subquantizer.unwrap_or(8);
439 format!("IVF{clusters},PQ{subq}x{bits}")
440 }
441 FaissIndexType::IvfSq => {
442 let clusters = self.config.num_clusters.unwrap_or(1024);
443 format!("IVF{clusters},SQ8")
444 }
445 FaissIndexType::HnswFlat => "HNSW32,Flat".to_string(),
446 FaissIndexType::Lsh => "LSH".to_string(),
447 FaissIndexType::Auto => self.auto_select_index_type()?,
448 FaissIndexType::Custom(s) => s.clone(),
449 };
450
451 Ok(index_str)
452 }
453
454 fn auto_select_index_type(&self) -> Result<String> {
456 let num_vectors = {
457 let vectors = self
458 .vectors
459 .read()
460 .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
461 vectors.len()
462 };
463
464 let dimension = self.config.dimension;
465
466 let index_str = if num_vectors < 10000 {
468 "Flat".to_string()
470 } else if num_vectors < 1000000 {
471 let clusters = (num_vectors as f32).sqrt() as usize;
473 if dimension > 128 {
474 format!("IVF{clusters},PQ16x8")
475 } else {
476 format!("IVF{clusters},Flat")
477 }
478 } else {
479 let clusters = (num_vectors as f32).sqrt() as usize;
481 format!("IVF{},PQ{}x8", clusters, std::cmp::min(dimension / 4, 64))
482 };
483
484 debug!(
485 "Auto-selected FAISS index: {} for {} vectors, {} dimensions",
486 index_str, num_vectors, dimension
487 );
488
489 Ok(index_str)
490 }
491
492 pub fn train(&self, training_vectors: &[Vec<f32>]) -> Result<()> {
494 let span = span!(Level::INFO, "faiss_train");
495 let _enter = span.enter();
496
497 if !self.requires_training() {
498 debug!("Index type does not require training");
499 return Ok(());
500 }
501
502 {
504 let mut state = self
505 .training_state
506 .write()
507 .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
508 state.training_start = Some(std::time::Instant::now());
509 state.training_vectors_count = training_vectors.len();
510 state.training_progress = 0.0;
511 }
512
513 if training_vectors.is_empty() {
515 return Err(AnyhowError::msg("Training vectors cannot be empty"));
516 }
517
518 for (i, vector) in training_vectors.iter().enumerate() {
519 if vector.len() != self.config.dimension {
520 return Err(AnyhowError::msg(format!(
521 "Training vector {} has dimension {}, expected {}",
522 i,
523 vector.len(),
524 self.config.dimension
525 )));
526 }
527 }
528
529 info!(
531 "Training FAISS index with {} vectors",
532 training_vectors.len()
533 );
534
535 for progress in 0..=10 {
537 std::thread::sleep(std::time::Duration::from_millis(100));
538 let mut state = self
539 .training_state
540 .write()
541 .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
542 state.training_progress = progress as f32 / 10.0;
543 }
544
545 {
547 let mut state = self
548 .training_state
549 .write()
550 .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
551 state.is_trained = true;
552 state.training_progress = 1.0;
553 }
554
555 {
557 let mut handle = self
558 .index_handle
559 .lock()
560 .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
561 if let Some(ref mut h) = *handle {
562 h.is_trained = true;
563 }
564 }
565
566 info!("FAISS index training completed successfully");
567 Ok(())
568 }
569
570 pub fn add_vectors(&self, vectors: Vec<Vec<f32>>, ids: Vec<String>) -> Result<()> {
572 let span = span!(Level::DEBUG, "faiss_add_vectors");
573 let _enter = span.enter();
574
575 if vectors.len() != ids.len() {
576 return Err(AnyhowError::msg(
577 "Number of vectors must match number of IDs",
578 ));
579 }
580
581 if self.requires_training() {
583 let state = self
584 .training_state
585 .read()
586 .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
587 if !state.is_trained {
588 return Err(AnyhowError::msg(
589 "Index must be trained before adding vectors",
590 ));
591 }
592 }
593
594 for (i, vector) in vectors.iter().enumerate() {
596 if vector.len() != self.config.dimension {
597 return Err(AnyhowError::msg(format!(
598 "Vector {} has dimension {}, expected {}",
599 i,
600 vector.len(),
601 self.config.dimension
602 )));
603 }
604 }
605
606 let start_time = std::time::Instant::now();
607
608 let mut vec_storage = self
610 .vectors
611 .write()
612 .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
613 let mut metadata_storage = self
614 .metadata
615 .write()
616 .map_err(|_| AnyhowError::msg("Failed to acquire metadata lock"))?;
617
618 for (vector, id) in vectors.iter().zip(ids.iter()) {
619 let index = vec_storage.len();
620 vec_storage.push(vector.clone());
621
622 let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
623 let metadata = VectorMetadata {
624 id: id.clone(),
625 timestamp: std::time::SystemTime::now(),
626 norm,
627 attributes: HashMap::new(),
628 };
629 metadata_storage.insert(index, metadata);
630 }
631
632 {
634 let mut stats = self
635 .stats
636 .write()
637 .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
638 stats.total_vectors += vectors.len();
639 stats.index_build_time_s += start_time.elapsed().as_secs_f64();
640 }
641
642 {
644 let mut handle = self
645 .index_handle
646 .lock()
647 .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
648 if let Some(ref mut h) = *handle {
649 h.num_vectors += vectors.len();
650 }
651 }
652
653 debug!("Added {} vectors to FAISS index", vectors.len());
654 Ok(())
655 }
656
657 pub fn search(
659 &self,
660 query_vector: &[f32],
661 params: &FaissSearchParams,
662 ) -> Result<Vec<(String, f32)>> {
663 let span = span!(Level::DEBUG, "faiss_search");
664 let _enter = span.enter();
665
666 if query_vector.len() != self.config.dimension {
667 return Err(AnyhowError::msg(format!(
668 "Query vector has dimension {}, expected {}",
669 query_vector.len(),
670 self.config.dimension
671 )));
672 }
673
674 let start_time = std::time::Instant::now();
675
676 let results = self.simulate_search(query_vector, params)?;
678
679 {
681 let mut stats = self
682 .stats
683 .write()
684 .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
685 stats.total_searches += 1;
686 let search_time_us = start_time.elapsed().as_micros() as f64;
687 stats.avg_search_time_us =
688 (stats.avg_search_time_us * (stats.total_searches - 1) as f64 + search_time_us)
689 / stats.total_searches as f64;
690 }
691
692 debug!("FAISS search completed in {:?}", start_time.elapsed());
693 Ok(results)
694 }
695
696 fn simulate_search(
698 &self,
699 query_vector: &[f32],
700 params: &FaissSearchParams,
701 ) -> Result<Vec<(String, f32)>> {
702 let vectors = self
703 .vectors
704 .read()
705 .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
706 let metadata = self
707 .metadata
708 .read()
709 .map_err(|_| AnyhowError::msg("Failed to acquire metadata lock"))?;
710
711 let mut results = Vec::new();
712
713 for (i, vector) in vectors.iter().enumerate() {
715 let distance = self.compute_distance(query_vector, vector);
716 if let Some(meta) = metadata.get(&i) {
717 results.push((meta.id.clone(), distance));
718 }
719 }
720
721 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
723 results.truncate(params.k);
724
725 Ok(results)
726 }
727
728 fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
730 match self.config.index_type {
731 FaissIndexType::FlatL2
732 | FaissIndexType::IvfFlat
733 | FaissIndexType::IvfPq
734 | FaissIndexType::IvfSq => {
735 a.iter()
737 .zip(b.iter())
738 .map(|(x, y)| (x - y).powi(2))
739 .sum::<f32>()
740 .sqrt()
741 }
742 FaissIndexType::FlatIP => {
743 -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
745 }
746 _ => {
747 a.iter()
749 .zip(b.iter())
750 .map(|(x, y)| (x - y).powi(2))
751 .sum::<f32>()
752 .sqrt()
753 }
754 }
755 }
756
757 pub fn get_statistics(&self) -> Result<FaissStatistics> {
759 let stats = self
760 .stats
761 .read()
762 .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
763 Ok(stats.clone())
764 }
765
766 pub fn save_index(&self, path: &Path) -> Result<()> {
768 let span = span!(Level::INFO, "faiss_save_index");
769 let _enter = span.enter();
770
771 if let Some(parent) = path.parent() {
773 std::fs::create_dir_all(parent)
774 .with_context(|| format!("Failed to create directory: {parent:?}"))?;
775 }
776
777 info!("Saving FAISS index to {:?}", path);
779
780 std::thread::sleep(std::time::Duration::from_millis(100));
782
783 Ok(())
784 }
785
786 pub fn load_index(&self, path: &Path) -> Result<()> {
788 let span = span!(Level::INFO, "faiss_load_index");
789 let _enter = span.enter();
790
791 if !path.exists() {
792 return Err(AnyhowError::msg(format!(
793 "Index file does not exist: {path:?}"
794 )));
795 }
796
797 info!("Loading FAISS index from {:?}", path);
799
800 std::thread::sleep(std::time::Duration::from_millis(100));
802
803 Ok(())
804 }
805
806 pub fn optimize(&self) -> Result<()> {
808 let span = span!(Level::INFO, "faiss_optimize");
809 let _enter = span.enter();
810
811 {
813 let mut stats = self
814 .stats
815 .write()
816 .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
817 stats.last_optimization = Some(std::time::SystemTime::now());
818 }
819
820 info!("FAISS index optimization completed");
821 Ok(())
822 }
823
824 pub fn get_memory_usage(&self) -> Result<usize> {
826 let vectors = self
827 .vectors
828 .read()
829 .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
830
831 let vector_memory = vectors.len() * self.config.dimension * std::mem::size_of::<f32>();
832 let metadata_memory = vectors.len() * std::mem::size_of::<VectorMetadata>();
833
834 Ok(vector_memory + metadata_memory)
835 }
836
837 pub fn dimension(&self) -> usize {
839 self.config.dimension
840 }
841
842 pub fn size(&self) -> usize {
844 self.vectors.read().map(|v| v.len()).unwrap_or(0)
845 }
846}
847
848impl VectorIndex for FaissIndex {
849 fn insert(&mut self, uri: String, vector: crate::Vector) -> Result<()> {
850 self.add_vectors(vec![vector.as_f32()], vec![uri])
851 }
852
853 fn search_knn(&self, query: &crate::Vector, k: usize) -> Result<Vec<(String, f32)>> {
854 let params = FaissSearchParams {
855 k,
856 ..Default::default()
857 };
858 self.search(&query.as_f32(), ¶ms)
859 }
860
861 fn search_threshold(
862 &self,
863 query: &crate::Vector,
864 threshold: f32,
865 ) -> Result<Vec<(String, f32)>> {
866 let params = FaissSearchParams {
867 k: 1000, ..Default::default()
870 };
871 let results = self.search(&query.as_f32(), ¶ms)?;
872 Ok(results
873 .into_iter()
874 .filter(|(_, score)| *score >= threshold)
875 .collect())
876 }
877
878 fn get_vector(&self, _uri: &str) -> Option<&crate::Vector> {
879 None
882 }
883}
884
885pub struct FaissFactory;
887
888impl FaissFactory {
889 pub fn create_optimized_index(
891 dimension: usize,
892 expected_size: usize,
893 use_gpu: bool,
894 ) -> Result<FaissIndex> {
895 let index_type = if expected_size < 10000 {
896 FaissIndexType::FlatL2
897 } else if expected_size < 1000000 {
898 FaissIndexType::IvfFlat
899 } else {
900 FaissIndexType::IvfPq
901 };
902
903 let config = FaissConfig {
904 index_type,
905 dimension,
906 training_sample_size: std::cmp::min(expected_size / 10, 100000),
907 num_clusters: Some((expected_size as f32).sqrt() as usize),
908 use_gpu,
909 ..Default::default()
910 };
911
912 FaissIndex::new(config)
913 }
914
915 pub fn create_gpu_index(dimension: usize, gpu_devices: Vec<u32>) -> Result<FaissIndex> {
917 let config = FaissConfig {
918 dimension,
919 use_gpu: true,
920 gpu_devices,
921 index_type: FaissIndexType::Auto,
922 ..Default::default()
923 };
924
925 FaissIndex::new(config)
926 }
927}
928
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created index reports its configured dimension and is empty.
    #[test]
    fn test_faiss_index_creation() {
        let config = FaissConfig {
            dimension: 128,
            index_type: FaissIndexType::FlatL2,
            ..Default::default()
        };

        let index = FaissIndex::new(config).unwrap();
        assert_eq!(index.dimension(), 128);
        assert_eq!(index.size(), 0);
    }

    /// Added vectors are searchable and the nearest one is ranked first.
    #[test]
    fn test_faiss_add_and_search() {
        let config = FaissConfig {
            dimension: 4,
            index_type: FaissIndexType::FlatL2,
            ..Default::default()
        };

        let index = FaissIndex::new(config).unwrap();

        let vectors = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        let ids = vec!["vec1".to_string(), "vec2".to_string(), "vec3".to_string()];

        index.add_vectors(vectors, ids).unwrap();
        assert_eq!(index.size(), 3);

        let query = vec![1.0, 0.1, 0.0, 0.0];
        let params = FaissSearchParams {
            k: 2,
            ..Default::default()
        };
        let results = index.search(&query, &params).unwrap();

        assert_eq!(results.len(), 2);
        // `vec1` is the closest stored vector to the query.
        assert_eq!(results[0].0, "vec1");
    }

    /// Training an IVF index marks the training state as complete.
    #[test]
    fn test_faiss_training() {
        let config = FaissConfig {
            dimension: 4,
            index_type: FaissIndexType::IvfFlat,
            num_clusters: Some(2),
            training_sample_size: 10,
            ..Default::default()
        };

        let index = FaissIndex::new(config).unwrap();

        let training_vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| vec![i as f32, (i % 2) as f32, 0.0, 0.0])
            .collect();

        index.train(&training_vectors).unwrap();

        let state = index.training_state.read().unwrap();
        assert!(state.is_trained);
        assert_eq!(state.training_progress, 1.0);
    }

    /// Both factory constructors honour the requested dimension and GPU flag.
    #[test]
    fn test_faiss_factory() {
        let index = FaissFactory::create_optimized_index(64, 1000, false).unwrap();
        assert_eq!(index.dimension(), 64);

        let gpu_index = FaissFactory::create_gpu_index(128, vec![0]).unwrap();
        assert_eq!(gpu_index.dimension(), 128);
        assert!(gpu_index.config.use_gpu);
    }

    /// With no vectors stored, automatic selection falls back to a flat index.
    #[test]
    fn test_faiss_auto_index_selection() {
        let config = FaissConfig {
            dimension: 64,
            index_type: FaissIndexType::Auto,
            ..Default::default()
        };

        let index = FaissIndex::new(config).unwrap();
        let index_str = index.faiss_index_string().unwrap();

        assert_eq!(index_str, "Flat");
    }
}