1use crate::index::VectorIndex;
8use anyhow::{Context, Error as AnyhowError, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12use std::sync::{Arc, Mutex, RwLock};
13use tracing::{debug, info, span, Level};
14
/// Configuration used to build a [`FaissIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissConfig {
    /// Index layout to build; mapped to a FAISS factory string by `FaissIndex`.
    pub index_type: FaissIndexType,
    /// Dimensionality of every stored and query vector; must be greater than zero.
    pub dimension: usize,
    /// Requested number of training samples; must be greater than zero.
    pub training_sample_size: usize,
    /// Number of IVF coarse clusters (`n` in `IVF{n},...`); required by the IVF index types.
    pub num_clusters: Option<usize>,
    /// Number of PQ sub-quantizers (`m` in `PQ{m}x{b}`); required by `IvfPq`.
    pub num_subquantizers: Option<usize>,
    /// Bits per PQ sub-quantizer code (`b` in `PQ{m}x{b}`); required by `IvfPq`.
    pub bits_per_subquantizer: Option<u8>,
    /// Whether a GPU device should be recorded for the index.
    pub use_gpu: bool,
    /// Candidate GPU device ids; only the first entry is used (0 if the list is empty).
    pub gpu_devices: Vec<u32>,
    /// Whether memory-mapped index files are requested (not consulted by `FaissIndex` here).
    pub enable_mmap: bool,
    /// Where and how indices are persisted.
    pub persistence: FaissPersistenceConfig,
    /// Automatic optimisation and monitoring settings.
    pub optimization: FaissOptimizationConfig,
}
41
42impl Default for FaissConfig {
43 fn default() -> Self {
44 Self {
45 index_type: FaissIndexType::FlatL2,
46 dimension: 384,
47 training_sample_size: 10000,
48 num_clusters: Some(1024),
49 num_subquantizers: Some(8),
50 bits_per_subquantizer: Some(8),
51 use_gpu: false,
52 gpu_devices: vec![0],
53 enable_mmap: true,
54 persistence: FaissPersistenceConfig::default(),
55 optimization: FaissOptimizationConfig::default(),
56 }
57 }
58}
59
/// Supported index layouts, each corresponding to a FAISS index-factory string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FaissIndexType {
    /// Brute-force exact search with Euclidean distance (`"Flat"`).
    FlatL2,
    /// Brute-force exact search ranked by inner product (`"Flat"`).
    FlatIP,
    /// Inverted file over uncompressed vectors (`"IVF{n},Flat"`).
    IvfFlat,
    /// Inverted file with product quantisation (`"IVF{n},PQ{m}x{b}"`).
    IvfPq,
    /// Inverted file with 8-bit scalar quantisation (`"IVF{n},SQ8"`).
    IvfSq,
    /// HNSW graph over uncompressed vectors (`"HNSW32,Flat"`).
    HnswFlat,
    /// Locality-sensitive hashing (`"LSH"`).
    Lsh,
    /// Layout chosen from the number of vectors currently stored.
    Auto,
    /// Raw FAISS factory string, used verbatim.
    Custom(String),
}
82
/// Persistence settings for FAISS indices (not consulted by `FaissIndex` in this file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissPersistenceConfig {
    /// Directory in which index files are stored.
    pub index_directory: PathBuf,
    /// Whether the index should be saved automatically.
    pub auto_save: bool,
    /// Interval between automatic saves (default 300; presumably seconds — TODO confirm).
    pub save_interval: u64,
    /// Whether saved index files should be compressed.
    pub compression: bool,
    /// Backup retention settings.
    pub backup: FaissBackupConfig,
}
97
98impl Default for FaissPersistenceConfig {
99 fn default() -> Self {
100 Self {
101 index_directory: PathBuf::from("./faiss_indices"),
102 auto_save: true,
103 save_interval: 300, compression: true,
105 backup: FaissBackupConfig::default(),
106 }
107 }
108}
109
/// Backup retention settings for persisted indices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissBackupConfig {
    /// Whether backups are taken at all.
    pub enabled: bool,
    /// Directory in which backups are written.
    pub backup_directory: PathBuf,
    /// Maximum number of backup versions to keep.
    pub max_versions: usize,
    /// Interval between backups (default 3600; presumably seconds — TODO confirm).
    pub backup_frequency: u64,
}
122
123impl Default for FaissBackupConfig {
124 fn default() -> Self {
125 Self {
126 enabled: true,
127 backup_directory: PathBuf::from("./faiss_backups"),
128 max_versions: 5,
129 backup_frequency: 3600, }
131 }
132}
133
/// Automatic optimisation settings (not consulted by `FaissIndex` in this file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissOptimizationConfig {
    /// Whether the index should be optimised automatically.
    pub auto_optimize: bool,
    /// How often to optimise (default 100000; presumably a count of operations — TODO confirm).
    pub optimization_frequency: usize,
    /// Whether search parameters may be tuned dynamically.
    pub dynamic_tuning: bool,
    /// Metrics collection settings.
    pub monitoring: FaissMonitoringConfig,
}
146
147impl Default for FaissOptimizationConfig {
148 fn default() -> Self {
149 Self {
150 auto_optimize: true,
151 optimization_frequency: 100000,
152 dynamic_tuning: true,
153 monitoring: FaissMonitoringConfig::default(),
154 }
155 }
156}
157
/// Metrics collection settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissMonitoringConfig {
    /// Whether monitoring is enabled.
    pub enabled: bool,
    /// Interval between metric collections (default 60; presumably seconds — TODO confirm).
    pub collection_interval: u64,
    /// Whether memory usage is tracked.
    pub track_memory: bool,
    /// Whether query metrics are tracked.
    pub track_queries: bool,
}
170
171impl Default for FaissMonitoringConfig {
172 fn default() -> Self {
173 Self {
174 enabled: true,
175 collection_interval: 60,
176 track_memory: true,
177 track_queries: true,
178 }
179 }
180}
181
/// Per-query search parameters.
///
/// The simulated search in `FaissIndex` only honours `k`; the remaining fields are
/// carried for API compatibility.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissSearchParams {
    /// Maximum number of results to return.
    pub k: usize,
    /// Number of IVF lists to probe.
    pub nprobe: Option<usize>,
    /// HNSW search-time candidate list size (`efSearch`).
    pub hnsw_ef: Option<usize>,
    /// Whether an exact (non-approximate) search is requested.
    pub exact_search: bool,
    /// Search timeout in milliseconds.
    pub timeout_ms: Option<u64>,
}
196
197impl Default for FaissSearchParams {
198 fn default() -> Self {
199 Self {
200 k: 10,
201 nprobe: Some(64),
202 hnsw_ef: Some(128),
203 exact_search: false,
204 timeout_ms: Some(5000),
205 }
206 }
207}
208
/// Simulated FAISS-backed vector index.
///
/// Vectors and their metadata are kept in memory behind shared locks. The index handle
/// only records descriptive state (factory string, counts, training flag); the code here
/// does not call into a native FAISS library.
pub struct FaissIndex {
    /// Configuration the index was created with.
    config: FaissConfig,
    /// Descriptive handle for the underlying index; `None` until initialised by `new`.
    index_handle: Arc<Mutex<Option<FaissIndexHandle>>>,
    /// Stored vectors; a vector's position in this list is its internal id.
    vectors: Arc<RwLock<Vec<Vec<f32>>>>,
    /// Per-vector metadata keyed by the vector's position in `vectors`.
    metadata: Arc<RwLock<HashMap<usize, VectorMetadata>>>,
    /// Running usage and performance statistics.
    stats: Arc<RwLock<FaissStatistics>>,
    /// Progress and outcome of training.
    training_state: Arc<RwLock<TrainingState>>,
}
224
/// Descriptive record of the (simulated) native FAISS index.
#[derive(Debug)]
pub struct FaissIndexHandle {
    /// FAISS index-factory string describing the layout, e.g. `"IVF1024,Flat"`.
    pub index_type: String,
    /// Number of vectors added through `FaissIndex::add_vectors`.
    pub num_vectors: usize,
    /// Vector dimensionality.
    pub dimension: usize,
    /// Whether the index is considered trained; set to `true` by `FaissIndex::train`.
    pub is_trained: bool,
    /// GPU device id when GPU use is configured, otherwise `None`.
    pub gpu_device: Option<u32>,
}
239
/// Metadata recorded for each stored vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorMetadata {
    /// External identifier supplied to `FaissIndex::add_vectors` (the URI for `VectorIndex::insert`).
    pub id: String,
    /// Wall-clock time at which the vector was added.
    pub timestamp: std::time::SystemTime,
    /// Euclidean (L2) norm of the vector, computed when it was added.
    pub norm: f32,
    /// Free-form attributes; `add_vectors` stores an empty map.
    pub attributes: HashMap<String, String>,
}
252
/// Progress and outcome of index training.
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Whether training has completed successfully.
    pub is_trained: bool,
    /// Training progress as a fraction from `0.0` to `1.0`.
    pub training_progress: f32,
    /// Monotonic time at which the last training run started.
    pub training_start: Option<std::time::Instant>,
    /// Number of vectors passed to the last training run.
    pub training_vectors_count: usize,
}
265
266impl Default for TrainingState {
267 fn default() -> Self {
268 Self {
269 is_trained: false,
270 training_progress: 0.0,
271 training_start: None,
272 training_vectors_count: 0,
273 }
274 }
275}
276
/// Running usage and performance statistics for a [`FaissIndex`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FaissStatistics {
    /// Total number of vectors added.
    pub total_vectors: usize,
    /// Total number of searches performed.
    pub total_searches: usize,
    /// Running mean search latency in microseconds.
    pub avg_search_time_us: f64,
    /// Estimated memory footprint in bytes.
    pub memory_usage_bytes: usize,
    /// GPU memory usage in bytes, when known.
    pub gpu_memory_usage_bytes: Option<usize>,
    /// Cumulative seconds spent inside `add_vectors`.
    pub index_build_time_s: f64,
    /// Time of the last `optimize` call.
    pub last_optimization: Option<std::time::SystemTime>,
    /// Historical performance snapshots.
    pub performance_history: Vec<PerformanceSnapshot>,
}
297
/// Point-in-time performance measurements.
///
/// NOTE(review): nothing in this file populates these; latency units are not established
/// here — confirm with the producer before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceSnapshot {
    /// When the snapshot was taken.
    pub timestamp: std::time::SystemTime,
    /// Median search latency.
    pub search_latency_p50: f64,
    /// 95th-percentile search latency.
    pub search_latency_p95: f64,
    /// 99th-percentile search latency.
    pub search_latency_p99: f64,
    /// Throughput in queries per second.
    pub throughput_qps: f64,
    /// Memory usage in megabytes.
    pub memory_usage_mb: f64,
    /// GPU utilisation, when known.
    pub gpu_utilization: Option<f32>,
}
314
315impl FaissIndex {
316 pub fn new(config: FaissConfig) -> Result<Self> {
318 let span = span!(Level::INFO, "faiss_index_new");
319 let _enter = span.enter();
320
321 Self::validate_config(&config)?;
323
324 let index = Self {
325 config: config.clone(),
326 index_handle: Arc::new(Mutex::new(None)),
327 vectors: Arc::new(RwLock::new(Vec::new())),
328 metadata: Arc::new(RwLock::new(HashMap::new())),
329 stats: Arc::new(RwLock::new(FaissStatistics::default())),
330 training_state: Arc::new(RwLock::new(TrainingState::default())),
331 };
332
333 index.initialize_faiss_index()?;
335
336 info!(
337 "Created FAISS index with type {:?}, dimension {}",
338 config.index_type, config.dimension
339 );
340
341 Ok(index)
342 }
343
344 fn validate_config(config: &FaissConfig) -> Result<()> {
346 if config.dimension == 0 {
347 return Err(AnyhowError::msg("Dimension must be greater than 0"));
348 }
349
350 if config.training_sample_size == 0 {
351 return Err(AnyhowError::msg(
352 "Training sample size must be greater than 0",
353 ));
354 }
355
356 match &config.index_type {
358 FaissIndexType::IvfFlat | FaissIndexType::IvfSq if config.num_clusters.is_none() => {
359 return Err(AnyhowError::msg(
360 "IVF indices require num_clusters to be set",
361 ));
362 }
363 FaissIndexType::IvfPq => {
364 if config.num_clusters.is_none() {
365 return Err(AnyhowError::msg(
366 "IVF indices require num_clusters to be set",
367 ));
368 }
369 if config.num_subquantizers.is_none() {
370 return Err(AnyhowError::msg(
371 "IVF-PQ requires num_subquantizers to be set",
372 ));
373 }
374 if config.bits_per_subquantizer.is_none() {
375 return Err(AnyhowError::msg(
376 "IVF-PQ requires bits_per_subquantizer to be set",
377 ));
378 }
379 }
380 _ => {}
381 }
382
383 Ok(())
384 }
385
386 fn initialize_faiss_index(&self) -> Result<()> {
388 let span = span!(Level::DEBUG, "initialize_faiss_index");
389 let _enter = span.enter();
390
391 let index_type_str = self.faiss_index_string()?;
393
394 let handle = FaissIndexHandle {
395 index_type: index_type_str,
396 num_vectors: 0,
397 dimension: self.config.dimension,
398 is_trained: self.requires_training(),
399 gpu_device: if self.config.use_gpu {
400 Some(self.config.gpu_devices.first().copied().unwrap_or(0))
401 } else {
402 None
403 },
404 };
405
406 let mut index_handle = self
407 .index_handle
408 .lock()
409 .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
410 *index_handle = Some(handle);
411
412 debug!("Initialized FAISS index: {}", self.faiss_index_string()?);
413 Ok(())
414 }
415
416 fn requires_training(&self) -> bool {
418 !matches!(
419 self.config.index_type,
420 FaissIndexType::FlatL2 | FaissIndexType::FlatIP
421 )
422 }
423
424 fn faiss_index_string(&self) -> Result<String> {
426 let index_str = match &self.config.index_type {
427 FaissIndexType::FlatL2 => "Flat".to_string(),
428 FaissIndexType::FlatIP => "Flat".to_string(),
429 FaissIndexType::IvfFlat => {
430 let clusters = self.config.num_clusters.unwrap_or(1024);
431 format!("IVF{clusters},Flat")
432 }
433 FaissIndexType::IvfPq => {
434 let clusters = self.config.num_clusters.unwrap_or(1024);
435 let subq = self.config.num_subquantizers.unwrap_or(8);
436 let bits = self.config.bits_per_subquantizer.unwrap_or(8);
437 format!("IVF{clusters},PQ{subq}x{bits}")
438 }
439 FaissIndexType::IvfSq => {
440 let clusters = self.config.num_clusters.unwrap_or(1024);
441 format!("IVF{clusters},SQ8")
442 }
443 FaissIndexType::HnswFlat => "HNSW32,Flat".to_string(),
444 FaissIndexType::Lsh => "LSH".to_string(),
445 FaissIndexType::Auto => self.auto_select_index_type()?,
446 FaissIndexType::Custom(s) => s.clone(),
447 };
448
449 Ok(index_str)
450 }
451
452 fn auto_select_index_type(&self) -> Result<String> {
454 let num_vectors = {
455 let vectors = self
456 .vectors
457 .read()
458 .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
459 vectors.len()
460 };
461
462 let dimension = self.config.dimension;
463
464 let index_str = if num_vectors < 10000 {
466 "Flat".to_string()
468 } else if num_vectors < 1000000 {
469 let clusters = (num_vectors as f32).sqrt() as usize;
471 if dimension > 128 {
472 format!("IVF{clusters},PQ16x8")
473 } else {
474 format!("IVF{clusters},Flat")
475 }
476 } else {
477 let clusters = (num_vectors as f32).sqrt() as usize;
479 format!("IVF{},PQ{}x8", clusters, std::cmp::min(dimension / 4, 64))
480 };
481
482 debug!(
483 "Auto-selected FAISS index: {} for {} vectors, {} dimensions",
484 index_str, num_vectors, dimension
485 );
486
487 Ok(index_str)
488 }
489
490 pub fn train(&self, training_vectors: &[Vec<f32>]) -> Result<()> {
492 let span = span!(Level::INFO, "faiss_train");
493 let _enter = span.enter();
494
495 if !self.requires_training() {
496 debug!("Index type does not require training");
497 return Ok(());
498 }
499
500 {
502 let mut state = self
503 .training_state
504 .write()
505 .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
506 state.training_start = Some(std::time::Instant::now());
507 state.training_vectors_count = training_vectors.len();
508 state.training_progress = 0.0;
509 }
510
511 if training_vectors.is_empty() {
513 return Err(AnyhowError::msg("Training vectors cannot be empty"));
514 }
515
516 for (i, vector) in training_vectors.iter().enumerate() {
517 if vector.len() != self.config.dimension {
518 return Err(AnyhowError::msg(format!(
519 "Training vector {} has dimension {}, expected {}",
520 i,
521 vector.len(),
522 self.config.dimension
523 )));
524 }
525 }
526
527 info!(
529 "Training FAISS index with {} vectors",
530 training_vectors.len()
531 );
532
533 for progress in 0..=10 {
535 std::thread::sleep(std::time::Duration::from_millis(100));
536 let mut state = self
537 .training_state
538 .write()
539 .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
540 state.training_progress = progress as f32 / 10.0;
541 }
542
543 {
545 let mut state = self
546 .training_state
547 .write()
548 .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
549 state.is_trained = true;
550 state.training_progress = 1.0;
551 }
552
553 {
555 let mut handle = self
556 .index_handle
557 .lock()
558 .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
559 if let Some(ref mut h) = *handle {
560 h.is_trained = true;
561 }
562 }
563
564 info!("FAISS index training completed successfully");
565 Ok(())
566 }
567
568 pub fn add_vectors(&self, vectors: Vec<Vec<f32>>, ids: Vec<String>) -> Result<()> {
570 let span = span!(Level::DEBUG, "faiss_add_vectors");
571 let _enter = span.enter();
572
573 if vectors.len() != ids.len() {
574 return Err(AnyhowError::msg(
575 "Number of vectors must match number of IDs",
576 ));
577 }
578
579 if self.requires_training() {
581 let state = self
582 .training_state
583 .read()
584 .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
585 if !state.is_trained {
586 return Err(AnyhowError::msg(
587 "Index must be trained before adding vectors",
588 ));
589 }
590 }
591
592 for (i, vector) in vectors.iter().enumerate() {
594 if vector.len() != self.config.dimension {
595 return Err(AnyhowError::msg(format!(
596 "Vector {} has dimension {}, expected {}",
597 i,
598 vector.len(),
599 self.config.dimension
600 )));
601 }
602 }
603
604 let start_time = std::time::Instant::now();
605
606 let mut vec_storage = self
608 .vectors
609 .write()
610 .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
611 let mut metadata_storage = self
612 .metadata
613 .write()
614 .map_err(|_| AnyhowError::msg("Failed to acquire metadata lock"))?;
615
616 for (vector, id) in vectors.iter().zip(ids.iter()) {
617 let index = vec_storage.len();
618 vec_storage.push(vector.clone());
619
620 let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
621 let metadata = VectorMetadata {
622 id: id.clone(),
623 timestamp: std::time::SystemTime::now(),
624 norm,
625 attributes: HashMap::new(),
626 };
627 metadata_storage.insert(index, metadata);
628 }
629
630 {
632 let mut stats = self
633 .stats
634 .write()
635 .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
636 stats.total_vectors += vectors.len();
637 stats.index_build_time_s += start_time.elapsed().as_secs_f64();
638 }
639
640 {
642 let mut handle = self
643 .index_handle
644 .lock()
645 .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
646 if let Some(ref mut h) = *handle {
647 h.num_vectors += vectors.len();
648 }
649 }
650
651 debug!("Added {} vectors to FAISS index", vectors.len());
652 Ok(())
653 }
654
655 pub fn search(
657 &self,
658 query_vector: &[f32],
659 params: &FaissSearchParams,
660 ) -> Result<Vec<(String, f32)>> {
661 let span = span!(Level::DEBUG, "faiss_search");
662 let _enter = span.enter();
663
664 if query_vector.len() != self.config.dimension {
665 return Err(AnyhowError::msg(format!(
666 "Query vector has dimension {}, expected {}",
667 query_vector.len(),
668 self.config.dimension
669 )));
670 }
671
672 let start_time = std::time::Instant::now();
673
674 let results = self.simulate_search(query_vector, params)?;
676
677 {
679 let mut stats = self
680 .stats
681 .write()
682 .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
683 stats.total_searches += 1;
684 let search_time_us = start_time.elapsed().as_micros() as f64;
685 stats.avg_search_time_us =
686 (stats.avg_search_time_us * (stats.total_searches - 1) as f64 + search_time_us)
687 / stats.total_searches as f64;
688 }
689
690 debug!("FAISS search completed in {:?}", start_time.elapsed());
691 Ok(results)
692 }
693
694 fn simulate_search(
696 &self,
697 query_vector: &[f32],
698 params: &FaissSearchParams,
699 ) -> Result<Vec<(String, f32)>> {
700 let vectors = self
701 .vectors
702 .read()
703 .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
704 let metadata = self
705 .metadata
706 .read()
707 .map_err(|_| AnyhowError::msg("Failed to acquire metadata lock"))?;
708
709 let mut results = Vec::new();
710
711 for (i, vector) in vectors.iter().enumerate() {
713 let distance = self.compute_distance(query_vector, vector);
714 if let Some(meta) = metadata.get(&i) {
715 results.push((meta.id.clone(), distance));
716 }
717 }
718
719 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
721 results.truncate(params.k);
722
723 Ok(results)
724 }
725
726 fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
728 match self.config.index_type {
729 FaissIndexType::FlatL2
730 | FaissIndexType::IvfFlat
731 | FaissIndexType::IvfPq
732 | FaissIndexType::IvfSq => {
733 a.iter()
735 .zip(b.iter())
736 .map(|(x, y)| (x - y).powi(2))
737 .sum::<f32>()
738 .sqrt()
739 }
740 FaissIndexType::FlatIP => {
741 -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
743 }
744 _ => {
745 a.iter()
747 .zip(b.iter())
748 .map(|(x, y)| (x - y).powi(2))
749 .sum::<f32>()
750 .sqrt()
751 }
752 }
753 }
754
755 pub fn get_statistics(&self) -> Result<FaissStatistics> {
757 let stats = self
758 .stats
759 .read()
760 .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
761 Ok(stats.clone())
762 }
763
764 pub fn save_index(&self, path: &Path) -> Result<()> {
766 let span = span!(Level::INFO, "faiss_save_index");
767 let _enter = span.enter();
768
769 if let Some(parent) = path.parent() {
771 std::fs::create_dir_all(parent)
772 .with_context(|| format!("Failed to create directory: {parent:?}"))?;
773 }
774
775 info!("Saving FAISS index to {:?}", path);
777
778 std::thread::sleep(std::time::Duration::from_millis(100));
780
781 Ok(())
782 }
783
784 pub fn load_index(&self, path: &Path) -> Result<()> {
786 let span = span!(Level::INFO, "faiss_load_index");
787 let _enter = span.enter();
788
789 if !path.exists() {
790 return Err(AnyhowError::msg(format!(
791 "Index file does not exist: {path:?}"
792 )));
793 }
794
795 info!("Loading FAISS index from {:?}", path);
797
798 std::thread::sleep(std::time::Duration::from_millis(100));
800
801 Ok(())
802 }
803
804 pub fn optimize(&self) -> Result<()> {
806 let span = span!(Level::INFO, "faiss_optimize");
807 let _enter = span.enter();
808
809 {
811 let mut stats = self
812 .stats
813 .write()
814 .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
815 stats.last_optimization = Some(std::time::SystemTime::now());
816 }
817
818 info!("FAISS index optimization completed");
819 Ok(())
820 }
821
822 pub fn get_memory_usage(&self) -> Result<usize> {
824 let vectors = self
825 .vectors
826 .read()
827 .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
828
829 let vector_memory = vectors.len() * self.config.dimension * std::mem::size_of::<f32>();
830 let metadata_memory = vectors.len() * std::mem::size_of::<VectorMetadata>();
831
832 Ok(vector_memory + metadata_memory)
833 }
834
835 pub fn dimension(&self) -> usize {
837 self.config.dimension
838 }
839
840 pub fn size(&self) -> usize {
842 self.vectors.read().map(|v| v.len()).unwrap_or(0)
843 }
844}
845
846impl VectorIndex for FaissIndex {
847 fn insert(&mut self, uri: String, vector: crate::Vector) -> Result<()> {
848 self.add_vectors(vec![vector.as_f32()], vec![uri])
849 }
850
851 fn search_knn(&self, query: &crate::Vector, k: usize) -> Result<Vec<(String, f32)>> {
852 let params = FaissSearchParams {
853 k,
854 ..Default::default()
855 };
856 self.search(&query.as_f32(), ¶ms)
857 }
858
859 fn search_threshold(
860 &self,
861 query: &crate::Vector,
862 threshold: f32,
863 ) -> Result<Vec<(String, f32)>> {
864 let params = FaissSearchParams {
865 k: 1000, ..Default::default()
868 };
869 let results = self.search(&query.as_f32(), ¶ms)?;
870 Ok(results
871 .into_iter()
872 .filter(|(_, score)| *score >= threshold)
873 .collect())
874 }
875
876 fn get_vector(&self, _uri: &str) -> Option<&crate::Vector> {
877 None
880 }
881}
882
/// Convenience constructors for commonly used [`FaissIndex`] configurations.
pub struct FaissFactory;
885
886impl FaissFactory {
887 pub fn create_optimized_index(
889 dimension: usize,
890 expected_size: usize,
891 use_gpu: bool,
892 ) -> Result<FaissIndex> {
893 let index_type = if expected_size < 10000 {
894 FaissIndexType::FlatL2
895 } else if expected_size < 1000000 {
896 FaissIndexType::IvfFlat
897 } else {
898 FaissIndexType::IvfPq
899 };
900
901 let config = FaissConfig {
902 index_type,
903 dimension,
904 training_sample_size: std::cmp::min(expected_size / 10, 100000),
905 num_clusters: Some((expected_size as f32).sqrt() as usize),
906 use_gpu,
907 ..Default::default()
908 };
909
910 FaissIndex::new(config)
911 }
912
913 pub fn create_gpu_index(dimension: usize, gpu_devices: Vec<u32>) -> Result<FaissIndex> {
915 let config = FaissConfig {
916 dimension,
917 use_gpu: true,
918 gpu_devices,
919 index_type: FaissIndexType::Auto,
920 ..Default::default()
921 };
922
923 FaissIndex::new(config)
924 }
925}
926
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// A flat index reports its configured dimension and starts empty.
    #[test]
    fn test_faiss_index_creation() -> Result<()> {
        let config = FaissConfig {
            dimension: 128,
            index_type: FaissIndexType::FlatL2,
            ..Default::default()
        };

        let index = FaissIndex::new(config)?;
        assert_eq!(index.dimension(), 128);
        assert_eq!(index.size(), 0);
        Ok(())
    }

    /// Added vectors are searchable and the closest one is ranked first.
    #[test]
    fn test_faiss_add_and_search() -> Result<()> {
        let config = FaissConfig {
            dimension: 4,
            index_type: FaissIndexType::FlatL2,
            ..Default::default()
        };

        let index = FaissIndex::new(config)?;

        let vectors = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        let ids = vec!["vec1".to_string(), "vec2".to_string(), "vec3".to_string()];

        index.add_vectors(vectors, ids)?;
        assert_eq!(index.size(), 3);

        let query = vec![1.0, 0.1, 0.0, 0.0];
        let params = FaissSearchParams {
            k: 2,
            ..Default::default()
        };
        let results = index.search(&query, &params)?;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "vec1");
        Ok(())
    }

    /// Training an IVF index marks it as fully trained.
    #[test]
    fn test_faiss_training() -> Result<()> {
        let config = FaissConfig {
            dimension: 4,
            index_type: FaissIndexType::IvfFlat,
            num_clusters: Some(2),
            training_sample_size: 10,
            ..Default::default()
        };

        let index = FaissIndex::new(config)?;

        let training_vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| vec![i as f32, (i % 2) as f32, 0.0, 0.0])
            .collect();

        index.train(&training_vectors)?;

        let state = index
            .training_state
            .read()
            .expect("training_state lock not poisoned");
        assert!(state.is_trained);
        assert_eq!(state.training_progress, 1.0);
        Ok(())
    }

    /// Factory constructors honour the requested dimension and GPU flag.
    #[test]
    fn test_faiss_factory() -> Result<()> {
        let index = FaissFactory::create_optimized_index(64, 1000, false)?;
        assert_eq!(index.dimension(), 64);

        let gpu_index = FaissFactory::create_gpu_index(128, vec![0])?;
        assert_eq!(gpu_index.dimension(), 128);
        assert!(gpu_index.config.use_gpu);
        Ok(())
    }

    /// An empty `Auto` index resolves to a flat layout.
    #[test]
    fn test_faiss_auto_index_selection() -> Result<()> {
        let config = FaissConfig {
            dimension: 64,
            index_type: FaissIndexType::Auto,
            ..Default::default()
        };

        let index = FaissIndex::new(config)?;
        let index_str = index.faiss_index_string()?;

        assert_eq!(index_str, "Flat");
        Ok(())
    }
}