use crate::index::VectorIndex;
use anyhow::{Context, Error as AnyhowError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use tracing::{debug, info, span, Level};
/// Configuration consumed by [`FaissIndex::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissConfig {
    /// Index layout to build.
    pub index_type: FaissIndexType,
    /// Required length of every training, stored and query vector; must be non-zero.
    pub dimension: usize,
    /// Must be non-zero to pass validation.
    pub training_sample_size: usize,
    /// Cluster count for the IVF index types; mandatory for `IvfFlat`, `IvfSq` and `IvfPq`.
    pub num_clusters: Option<usize>,
    /// Sub-quantizer count; mandatory for `IvfPq`.
    pub num_subquantizers: Option<usize>,
    /// Bits per sub-quantizer; mandatory for `IvfPq`.
    pub bits_per_subquantizer: Option<u8>,
    /// When set, the index handle records a GPU device id.
    pub use_gpu: bool,
    /// Candidate GPU device ids; only the first entry is read by the code visible here.
    pub gpu_devices: Vec<u32>,
    /// NOTE(review): not consulted by the code visible here — confirm intended use.
    pub enable_mmap: bool,
    /// On-disk persistence settings.
    pub persistence: FaissPersistenceConfig,
    /// Optimization and monitoring settings.
    pub optimization: FaissOptimizationConfig,
}

impl Default for FaissConfig {
    /// `FlatL2` index type, dimension 384, GPU disabled, nested settings at their defaults.
    fn default() -> Self {
        let persistence = FaissPersistenceConfig::default();
        let optimization = FaissOptimizationConfig::default();
        Self {
            index_type: FaissIndexType::FlatL2,
            dimension: 384,
            training_sample_size: 10_000,
            num_clusters: Some(1024),
            num_subquantizers: Some(8),
            bits_per_subquantizer: Some(8),
            use_gpu: false,
            gpu_devices: vec![0],
            enable_mmap: true,
            persistence,
            optimization,
        }
    }
}
/// Index layouts selectable through [`FaissConfig::index_type`].
///
/// Each variant is translated into a factory-style description string by
/// `FaissIndex::faiss_index_string`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FaissIndexType {
    /// Described as `"Flat"`; searched with Euclidean distance.
    FlatL2,
    /// Described as `"Flat"`; searched with the negated dot product.
    FlatIP,
    /// Described as `"IVF<clusters>,Flat"`.
    IvfFlat,
    /// Described as `"IVF<clusters>,PQ<subquantizers>x<bits>"`.
    IvfPq,
    /// Described as `"IVF<clusters>,SQ8"`.
    IvfSq,
    /// Described as `"HNSW32,Flat"`.
    HnswFlat,
    /// Described as `"LSH"`.
    Lsh,
    /// Description is chosen from the current vector count and the dimension.
    Auto,
    /// Caller-supplied description string, used verbatim.
    Custom(String),
}
/// Persistence settings nested inside [`FaissConfig`].
///
/// NOTE(review): none of these fields are read by the code visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissPersistenceConfig {
    /// Directory intended to hold saved indices.
    pub index_directory: PathBuf,
    /// Whether automatic saving is requested.
    pub auto_save: bool,
    /// Interval between automatic saves — presumably seconds; TODO confirm the unit.
    pub save_interval: u64,
    /// Whether saved data should be compressed.
    pub compression: bool,
    /// Backup settings.
    pub backup: FaissBackupConfig,
}

impl Default for FaissPersistenceConfig {
    /// Auto-save and compression enabled, `./faiss_indices` as directory, interval 300.
    fn default() -> Self {
        let backup = FaissBackupConfig::default();
        Self {
            index_directory: PathBuf::from("./faiss_indices"),
            auto_save: true,
            save_interval: 300,
            compression: true,
            backup,
        }
    }
}
/// Backup settings nested inside [`FaissPersistenceConfig`].
///
/// NOTE(review): none of these fields are read by the code visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissBackupConfig {
    /// Whether backups are requested.
    pub enabled: bool,
    /// Directory intended to hold backups.
    pub backup_directory: PathBuf,
    /// Number of backup versions to keep.
    pub max_versions: usize,
    /// Interval between backups — presumably seconds; TODO confirm the unit.
    pub backup_frequency: u64,
}

impl Default for FaissBackupConfig {
    /// Backups enabled, `./faiss_backups` as directory, 5 versions, frequency 3600.
    fn default() -> Self {
        Self {
            enabled: true,
            backup_directory: PathBuf::from("./faiss_backups"),
            max_versions: 5,
            backup_frequency: 3_600,
        }
    }
}
/// Optimization settings nested inside [`FaissConfig`].
///
/// NOTE(review): none of these fields are read by the code visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissOptimizationConfig {
    /// Whether automatic optimization is requested.
    pub auto_optimize: bool,
    /// Optimization cadence — presumably a count of operations; TODO confirm.
    pub optimization_frequency: usize,
    /// Whether dynamic parameter tuning is requested.
    pub dynamic_tuning: bool,
    /// Monitoring settings.
    pub monitoring: FaissMonitoringConfig,
}

impl Default for FaissOptimizationConfig {
    /// Auto-optimization and dynamic tuning enabled, frequency 100000.
    fn default() -> Self {
        let monitoring = FaissMonitoringConfig::default();
        Self {
            auto_optimize: true,
            optimization_frequency: 100_000,
            dynamic_tuning: true,
            monitoring,
        }
    }
}
/// Monitoring settings nested inside [`FaissOptimizationConfig`].
///
/// NOTE(review): none of these fields are read by the code visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissMonitoringConfig {
    /// Whether monitoring is requested.
    pub enabled: bool,
    /// Interval between metric collections — presumably seconds; TODO confirm the unit.
    pub collection_interval: u64,
    /// Whether memory usage should be tracked.
    pub track_memory: bool,
    /// Whether queries should be tracked.
    pub track_queries: bool,
}

impl Default for FaissMonitoringConfig {
    /// Everything enabled, collection interval 60.
    fn default() -> Self {
        Self {
            enabled: true,
            collection_interval: 60,
            track_memory: true,
            track_queries: true,
        }
    }
}
/// Per-query parameters accepted by [`FaissIndex::search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaissSearchParams {
    /// Maximum number of results to return.
    pub k: usize,
    /// NOTE(review): not consulted by the search code visible here.
    pub nprobe: Option<usize>,
    /// NOTE(review): not consulted by the search code visible here.
    pub hnsw_ef: Option<usize>,
    /// NOTE(review): not consulted by the search code visible here.
    pub exact_search: bool,
    /// NOTE(review): not consulted by the search code visible here.
    pub timeout_ms: Option<u64>,
}

impl Default for FaissSearchParams {
    /// `k = 10`, `nprobe = 64`, `hnsw_ef = 128`, approximate search, 5000 ms timeout.
    fn default() -> Self {
        Self {
            k: 10,
            nprobe: Some(64),
            hnsw_ef: Some(128),
            exact_search: false,
            timeout_ms: Some(5_000),
        }
    }
}
/// In-memory vector index driven by a [`FaissConfig`].
///
/// Every piece of mutable state sits behind its own lock, so all inherent
/// methods take `&self`.
pub struct FaissIndex {
    /// Configuration the index was created with.
    config: FaissConfig,
    /// Descriptor of the underlying index; filled in during construction.
    index_handle: Arc<Mutex<Option<FaissIndexHandle>>>,
    /// Stored vectors, in insertion order.
    vectors: Arc<RwLock<Vec<Vec<f32>>>>,
    /// Metadata keyed by the vector's position in `vectors`.
    metadata: Arc<RwLock<HashMap<usize, VectorMetadata>>>,
    /// Running counters and timings.
    stats: Arc<RwLock<FaissStatistics>>,
    /// Progress and outcome of training.
    training_state: Arc<RwLock<TrainingState>>,
}
/// Descriptor of the underlying index held by `FaissIndex`.
#[derive(Debug)]
pub struct FaissIndexHandle {
    /// Factory-style description string of the index.
    pub index_type: String,
    /// Number of vectors added so far.
    pub num_vectors: usize,
    /// Vector dimension the index was created with.
    pub dimension: usize,
    /// Whether the index is considered trained.
    pub is_trained: bool,
    /// GPU device id, or `None` when GPU use is disabled.
    pub gpu_device: Option<u32>,
}
/// Per-vector bookkeeping stored alongside each vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorMetadata {
    /// Caller-supplied identifier returned in search results.
    pub id: String,
    /// Wall-clock time at which the vector was added.
    pub timestamp: std::time::SystemTime,
    /// Euclidean norm of the vector (square root of the sum of squares).
    pub norm: f32,
    /// Free-form key/value attributes; created empty on insertion.
    pub attributes: HashMap<String, String>,
}
/// Progress and outcome of index training.
///
/// The derived `Default` yields an untrained state: `false`, `0.0`, `None`, `0`.
#[derive(Debug, Clone, Default)]
pub struct TrainingState {
    /// Set once training has finished.
    pub is_trained: bool,
    /// Fraction of training completed, from `0.0` to `1.0`.
    pub training_progress: f32,
    /// Moment the most recent training run started, if any.
    pub training_start: Option<std::time::Instant>,
    /// Number of vectors handed to the most recent training run.
    pub training_vectors_count: usize,
}
/// Counters and timings reported by [`FaissIndex::get_statistics`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FaissStatistics {
    /// Number of vectors added so far.
    pub total_vectors: usize,
    /// Number of completed searches.
    pub total_searches: usize,
    /// Running mean of search duration, in microseconds.
    pub avg_search_time_us: f64,
    /// NOTE(review): never written by the code visible here.
    pub memory_usage_bytes: usize,
    /// NOTE(review): never written by the code visible here.
    pub gpu_memory_usage_bytes: Option<usize>,
    /// Accumulated wall time spent adding vectors, in seconds.
    pub index_build_time_s: f64,
    /// Time of the most recent `optimize` call, if any.
    pub last_optimization: Option<std::time::SystemTime>,
    /// NOTE(review): never written by the code visible here.
    pub performance_history: Vec<PerformanceSnapshot>,
}
/// One entry of [`FaissStatistics::performance_history`].
///
/// NOTE(review): nothing in the code visible here constructs this type, so the
/// units below are inferred from the field names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceSnapshot {
    /// When the snapshot was taken.
    pub timestamp: std::time::SystemTime,
    /// Median search latency.
    pub search_latency_p50: f64,
    /// 95th-percentile search latency.
    pub search_latency_p95: f64,
    /// 99th-percentile search latency.
    pub search_latency_p99: f64,
    /// Queries per second.
    pub throughput_qps: f64,
    /// Memory usage in megabytes.
    pub memory_usage_mb: f64,
    /// GPU utilization, when available.
    pub gpu_utilization: Option<f32>,
}
impl FaissIndex {
/// Builds an empty index from `config`.
///
/// # Errors
///
/// Fails when `config` does not pass `validate_config` or when the index
/// handle cannot be initialized.
pub fn new(config: FaissConfig) -> Result<Self> {
    let span = span!(Level::INFO, "faiss_index_new");
    let _guard = span.enter();

    Self::validate_config(&config)?;

    // Every shared container starts out empty, i.e. at its `Default` value.
    let created = Self {
        config,
        index_handle: Arc::default(),
        vectors: Arc::default(),
        metadata: Arc::default(),
        stats: Arc::default(),
        training_state: Arc::default(),
    };
    created.initialize_faiss_index()?;

    info!(
        "Created FAISS index with type {:?}, dimension {}",
        created.config.index_type, created.config.dimension
    );
    Ok(created)
}
/// Rejects configurations the index cannot be built from.
///
/// # Errors
///
/// * `dimension` or `training_sample_size` is zero.
/// * `IvfFlat`, `IvfSq` or `IvfPq` is requested without `num_clusters`.
/// * `IvfPq` is requested without `num_subquantizers` or `bits_per_subquantizer`.
fn validate_config(config: &FaissConfig) -> Result<()> {
    fn invalid(reason: &'static str) -> Result<()> {
        Err(AnyhowError::msg(reason))
    }

    if config.dimension == 0 {
        return invalid("Dimension must be greater than 0");
    }
    if config.training_sample_size == 0 {
        return invalid("Training sample size must be greater than 0");
    }

    let is_ivf = matches!(
        config.index_type,
        FaissIndexType::IvfFlat | FaissIndexType::IvfSq | FaissIndexType::IvfPq
    );
    if is_ivf && config.num_clusters.is_none() {
        return invalid("IVF indices require num_clusters to be set");
    }

    if matches!(config.index_type, FaissIndexType::IvfPq) {
        if config.num_subquantizers.is_none() {
            return invalid("IVF-PQ requires num_subquantizers to be set");
        }
        if config.bits_per_subquantizer.is_none() {
            return invalid("IVF-PQ requires bits_per_subquantizer to be set");
        }
    }

    Ok(())
}
/// Creates the index handle for the configured index type and stores it.
///
/// The handle starts out trained exactly when the index type needs no
/// training; `train` flips the flag for the remaining types.
///
/// # Errors
///
/// Fails when the description string cannot be built or the handle lock is
/// poisoned.
fn initialize_faiss_index(&self) -> Result<()> {
    let span = span!(Level::DEBUG, "initialize_faiss_index");
    let _enter = span.enter();

    let handle = FaissIndexHandle {
        index_type: self.faiss_index_string()?,
        num_vectors: 0,
        dimension: self.config.dimension,
        // An index that still has to be trained must not be reported as trained.
        is_trained: !self.requires_training(),
        gpu_device: self
            .config
            .use_gpu
            .then(|| self.config.gpu_devices.first().copied().unwrap_or(0)),
    };

    let mut slot = self
        .index_handle
        .lock()
        .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
    // Log the stored description instead of computing it a second time.
    let stored = slot.insert(handle);
    debug!("Initialized FAISS index: {}", stored.index_type);
    Ok(())
}
/// Reports whether the configured index type must be trained before vectors
/// can be added. Only the two flat types are exempt.
fn requires_training(&self) -> bool {
    match &self.config.index_type {
        FaissIndexType::FlatL2 | FaissIndexType::FlatIP => false,
        FaissIndexType::IvfFlat
        | FaissIndexType::IvfPq
        | FaissIndexType::IvfSq
        | FaissIndexType::HnswFlat
        | FaissIndexType::Lsh
        | FaissIndexType::Auto
        | FaissIndexType::Custom(_) => true,
    }
}
/// Renders the configured index type as a factory-style description string.
///
/// Missing optional parameters fall back to 1024 clusters, 8 sub-quantizers
/// and 8 bits.
///
/// # Errors
///
/// Only the `Auto` variant can fail, when the vector count cannot be read.
fn faiss_index_string(&self) -> Result<String> {
    let cfg = &self.config;
    let clusters = || cfg.num_clusters.unwrap_or(1024);

    Ok(match &cfg.index_type {
        FaissIndexType::FlatL2 | FaissIndexType::FlatIP => String::from("Flat"),
        FaissIndexType::IvfFlat => format!("IVF{},Flat", clusters()),
        FaissIndexType::IvfPq => format!(
            "IVF{},PQ{}x{}",
            clusters(),
            cfg.num_subquantizers.unwrap_or(8),
            cfg.bits_per_subquantizer.unwrap_or(8)
        ),
        FaissIndexType::IvfSq => format!("IVF{},SQ8", clusters()),
        FaissIndexType::HnswFlat => String::from("HNSW32,Flat"),
        FaissIndexType::Lsh => String::from("LSH"),
        FaissIndexType::Auto => self.auto_select_index_type()?,
        FaissIndexType::Custom(spec) => spec.clone(),
    })
}
/// Picks a description string from the current vector count and the dimension.
///
/// * fewer than 10 000 vectors: `"Flat"`
/// * fewer than 1 000 000 vectors: IVF with `sqrt(n)` clusters, using PQ when
///   the dimension exceeds 128 and flat storage otherwise
/// * otherwise: IVF with `sqrt(n)` clusters and PQ
///
/// The PQ sub-quantizer count is the largest divisor of the dimension that
/// does not exceed the original cap (16, or `min(dimension / 4, 64)`), so the
/// description never contains a count of zero or one that does not divide the
/// dimension.
///
/// # Errors
///
/// Fails when the vectors lock is poisoned.
fn auto_select_index_type(&self) -> Result<String> {
    /// Largest value in `1..=cap` that divides `dimension` evenly.
    fn pq_subquantizers(dimension: usize, cap: usize) -> usize {
        (1..=cap.max(1))
            .rev()
            .find(|m| dimension % m == 0)
            .unwrap_or(1)
    }

    let num_vectors = self
        .vectors
        .read()
        .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?
        .len();
    let dimension = self.config.dimension;

    let index_str = if num_vectors < 10_000 {
        "Flat".to_string()
    } else {
        let clusters = (num_vectors as f32).sqrt() as usize;
        if num_vectors >= 1_000_000 {
            let subq = pq_subquantizers(dimension, std::cmp::min(dimension / 4, 64));
            format!("IVF{clusters},PQ{subq}x8")
        } else if dimension > 128 {
            let subq = pq_subquantizers(dimension, 16);
            format!("IVF{clusters},PQ{subq}x8")
        } else {
            format!("IVF{clusters},Flat")
        }
    };

    debug!(
        "Auto-selected FAISS index: {} for {} vectors, {} dimensions",
        index_str, num_vectors, dimension
    );
    Ok(index_str)
}
/// Trains the index on `training_vectors`.
///
/// Returns immediately for index types that need no training. The input is
/// validated before any training state is touched, so a rejected call leaves
/// the previously recorded state intact.
///
/// # Errors
///
/// * `training_vectors` is empty.
/// * A training vector's length differs from the configured dimension.
/// * One of the internal locks is poisoned.
pub fn train(&self, training_vectors: &[Vec<f32>]) -> Result<()> {
    let span = span!(Level::INFO, "faiss_train");
    let _enter = span.enter();

    if !self.requires_training() {
        debug!("Index type does not require training");
        return Ok(());
    }

    if training_vectors.is_empty() {
        return Err(AnyhowError::msg("Training vectors cannot be empty"));
    }
    let expected = self.config.dimension;
    if let Some((i, bad)) = training_vectors
        .iter()
        .enumerate()
        .find(|(_, v)| v.len() != expected)
    {
        return Err(AnyhowError::msg(format!(
            "Training vector {} has dimension {}, expected {}",
            i,
            bad.len(),
            expected
        )));
    }

    // Only record the start of a run once the input is known to be usable.
    {
        let mut state = self
            .training_state
            .write()
            .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
        state.training_start = Some(std::time::Instant::now());
        state.training_vectors_count = training_vectors.len();
        state.training_progress = 0.0;
    }

    info!(
        "Training FAISS index with {} vectors",
        training_vectors.len()
    );

    // Placeholder for the real training pass: progress advances in eleven
    // steps of 100 ms each. The lock is re-acquired per step and released
    // before the next sleep.
    for progress in 0..=10 {
        std::thread::sleep(std::time::Duration::from_millis(100));
        let mut state = self
            .training_state
            .write()
            .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
        state.training_progress = progress as f32 / 10.0;
    }

    {
        let mut state = self
            .training_state
            .write()
            .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?;
        state.is_trained = true;
        state.training_progress = 1.0;
    }
    {
        let mut handle = self
            .index_handle
            .lock()
            .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
        if let Some(h) = handle.as_mut() {
            h.is_trained = true;
        }
    }

    info!("FAISS index training completed successfully");
    Ok(())
}
/// Appends `vectors` to the index, pairing each with the id at the same position.
///
/// Both arguments are consumed, so vectors and ids are moved into storage
/// rather than cloned.
///
/// # Errors
///
/// * `vectors` and `ids` differ in length.
/// * The index type requires training and `train` has not completed.
/// * A vector's length differs from the configured dimension.
/// * One of the internal locks is poisoned.
pub fn add_vectors(&self, vectors: Vec<Vec<f32>>, ids: Vec<String>) -> Result<()> {
    let span = span!(Level::DEBUG, "faiss_add_vectors");
    let _enter = span.enter();

    if vectors.len() != ids.len() {
        return Err(AnyhowError::msg(
            "Number of vectors must match number of IDs",
        ));
    }

    if self.requires_training() {
        let trained = self
            .training_state
            .read()
            .map_err(|_| AnyhowError::msg("Failed to acquire training state lock"))?
            .is_trained;
        if !trained {
            return Err(AnyhowError::msg(
                "Index must be trained before adding vectors",
            ));
        }
    }

    let expected = self.config.dimension;
    if let Some((i, bad)) = vectors
        .iter()
        .enumerate()
        .find(|(_, v)| v.len() != expected)
    {
        return Err(AnyhowError::msg(format!(
            "Vector {} has dimension {}, expected {}",
            i,
            bad.len(),
            expected
        )));
    }

    // Captured up front because `vectors` is moved into storage below.
    let added = vectors.len();
    let start_time = std::time::Instant::now();

    // The storage locks are scoped so they are released before the stats and
    // handle locks are taken.
    {
        let mut vec_storage = self
            .vectors
            .write()
            .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
        let mut metadata_storage = self
            .metadata
            .write()
            .map_err(|_| AnyhowError::msg("Failed to acquire metadata lock"))?;
        vec_storage.reserve(added);
        metadata_storage.reserve(added);

        for (vector, id) in vectors.into_iter().zip(ids) {
            let norm = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
            // Metadata is keyed by the position the vector is about to occupy.
            let index = vec_storage.len();
            vec_storage.push(vector);
            metadata_storage.insert(
                index,
                VectorMetadata {
                    id,
                    timestamp: std::time::SystemTime::now(),
                    norm,
                    attributes: HashMap::new(),
                },
            );
        }
    }

    {
        let mut stats = self
            .stats
            .write()
            .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
        stats.total_vectors += added;
        stats.index_build_time_s += start_time.elapsed().as_secs_f64();
    }
    {
        let mut handle = self
            .index_handle
            .lock()
            .map_err(|_| AnyhowError::msg("Failed to acquire index handle lock"))?;
        if let Some(h) = handle.as_mut() {
            h.num_vectors += added;
        }
    }

    debug!("Added {} vectors to FAISS index", added);
    Ok(())
}
/// Returns up to `params.k` `(id, distance)` pairs for `query_vector`, closest first.
///
/// Every successful call also updates the search counter and the running
/// mean search time.
///
/// # Errors
///
/// Fails when the query length differs from the configured dimension or when
/// an internal lock is poisoned.
pub fn search(
    &self,
    query_vector: &[f32],
    params: &FaissSearchParams,
) -> Result<Vec<(String, f32)>> {
    let span = span!(Level::DEBUG, "faiss_search");
    let _guard = span.enter();

    let expected = self.config.dimension;
    let actual = query_vector.len();
    if actual != expected {
        return Err(AnyhowError::msg(format!(
            "Query vector has dimension {}, expected {}",
            actual, expected
        )));
    }

    let started = std::time::Instant::now();
    let hits = self.simulate_search(query_vector, params)?;

    {
        let mut stats = self
            .stats
            .write()
            .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
        // Incremental mean: weight the old mean by the number of earlier searches.
        let prior_searches = stats.total_searches as f64;
        stats.total_searches += 1;
        let elapsed_us = started.elapsed().as_micros() as f64;
        stats.avg_search_time_us = (stats.avg_search_time_us * prior_searches + elapsed_us)
            / stats.total_searches as f64;
    }

    debug!("FAISS search completed in {:?}", started.elapsed());
    Ok(hits)
}
/// Brute-force search: scores every stored vector against `query_vector`,
/// sorts by ascending distance and keeps the first `params.k` entries.
///
/// Vectors without a metadata entry are skipped. Distances that are NaN are
/// ordered after every real distance, which keeps the comparator a total
/// order as `sort_by` requires.
///
/// # Errors
///
/// Fails when the vectors or metadata lock is poisoned.
fn simulate_search(
    &self,
    query_vector: &[f32],
    params: &FaissSearchParams,
) -> Result<Vec<(String, f32)>> {
    /// Ascending order on distances with NaN placed last.
    fn nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        }
    }

    let vectors = self
        .vectors
        .read()
        .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
    let metadata = self
        .metadata
        .read()
        .map_err(|_| AnyhowError::msg("Failed to acquire metadata lock"))?;

    let mut results: Vec<(String, f32)> = vectors
        .iter()
        .enumerate()
        .filter_map(|(i, vector)| {
            metadata
                .get(&i)
                .map(|meta| (meta.id.clone(), self.compute_distance(query_vector, vector)))
        })
        .collect();

    // Stable sort, so equal distances keep insertion order.
    results.sort_by(|a, b| nan_last(a.1, b.1));
    results.truncate(params.k);
    Ok(results)
}
/// Scores `b` against `a`; smaller means closer.
///
/// `FlatIP` uses the negated dot product; every other index type uses the
/// Euclidean distance. Elements are paired with `zip`, so any surplus in the
/// longer slice is ignored.
fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().zip(b.iter());
    if matches!(self.config.index_type, FaissIndexType::FlatIP) {
        let dot: f32 = pairs.map(|(x, y)| x * y).sum();
        -dot
    } else {
        let squared: f32 = pairs.map(|(x, y)| (x - y).powi(2)).sum();
        squared.sqrt()
    }
}
pub fn get_statistics(&self) -> Result<FaissStatistics> {
let stats = self
.stats
.read()
.map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
Ok(stats.clone())
}
/// Prepares `path` for saving the index.
///
/// Creates the parent directory of `path` if needed. No index data is
/// written by this implementation; the save itself is a 100 ms placeholder.
///
/// # Errors
///
/// Fails when the parent directory cannot be created.
pub fn save_index(&self, path: &Path) -> Result<()> {
    let span = span!(Level::INFO, "faiss_save_index");
    let _guard = span.enter();

    path.parent()
        .map(|parent| {
            std::fs::create_dir_all(parent)
                .with_context(|| format!("Failed to create directory: {parent:?}"))
        })
        .transpose()?;

    info!("Saving FAISS index to {:?}", path);
    std::thread::sleep(std::time::Duration::from_millis(100));
    Ok(())
}
/// Checks that `path` exists and simulates loading an index from it.
///
/// No index data is read by this implementation; the load itself is a
/// 100 ms placeholder.
///
/// # Errors
///
/// Fails when `path` does not exist.
pub fn load_index(&self, path: &Path) -> Result<()> {
    let span = span!(Level::INFO, "faiss_load_index");
    let _guard = span.enter();

    if path.exists() {
        info!("Loading FAISS index from {:?}", path);
        std::thread::sleep(std::time::Duration::from_millis(100));
        Ok(())
    } else {
        Err(AnyhowError::msg(format!(
            "Index file does not exist: {path:?}"
        )))
    }
}
/// Records the current time as the moment of the last optimization.
///
/// No restructuring of the stored data happens in this implementation.
///
/// # Errors
///
/// Fails when the stats lock is poisoned.
pub fn optimize(&self) -> Result<()> {
    let span = span!(Level::INFO, "faiss_optimize");
    let _guard = span.enter();

    let mut stats_guard = self
        .stats
        .write()
        .map_err(|_| AnyhowError::msg("Failed to acquire stats lock"))?;
    stats_guard.last_optimization = Some(std::time::SystemTime::now());
    // Release the lock before logging.
    drop(stats_guard);

    info!("FAISS index optimization completed");
    Ok(())
}
/// Estimates the bytes used by stored vectors and their metadata.
///
/// Counts `dimension` `f32` values plus one `VectorMetadata` value per stored
/// vector; heap data owned by the metadata (ids, attributes) is not included.
///
/// # Errors
///
/// Fails when the vectors lock is poisoned.
pub fn get_memory_usage(&self) -> Result<usize> {
    let stored = self
        .vectors
        .read()
        .map_err(|_| AnyhowError::msg("Failed to acquire vectors lock"))?;
    let count = stored.len();

    let payload_bytes = count * self.config.dimension * std::mem::size_of::<f32>();
    let bookkeeping_bytes = count * std::mem::size_of::<VectorMetadata>();
    Ok(payload_bytes + bookkeeping_bytes)
}
/// Vector dimension the index was configured with.
pub fn dimension(&self) -> usize {
    self.config.dimension
}

/// Number of stored vectors; reports `0` when the vectors lock is poisoned.
pub fn size(&self) -> usize {
    match self.vectors.read() {
        Ok(stored) => stored.len(),
        Err(_) => 0,
    }
}
}
/// Adapter exposing [`FaissIndex`] through the crate-wide [`VectorIndex`] trait.
impl VectorIndex for FaissIndex {
    /// Stores `vector` under `uri` by delegating to [`FaissIndex::add_vectors`].
    fn insert(&mut self, uri: String, vector: crate::Vector) -> Result<()> {
        self.add_vectors(vec![vector.as_f32()], vec![uri])
    }

    /// Returns the `k` closest entries to `query` as `(id, distance)` pairs.
    fn search_knn(&self, query: &crate::Vector, k: usize) -> Result<Vec<(String, f32)>> {
        let params = FaissSearchParams {
            k,
            ..Default::default()
        };
        self.search(&query.as_f32(), &params)
    }

    /// Returns every entry whose score is at least `threshold`.
    ///
    /// All stored vectors are considered; the candidate set is not capped at
    /// a fixed size.
    fn search_threshold(
        &self,
        query: &crate::Vector,
        threshold: f32,
    ) -> Result<Vec<(String, f32)>> {
        // Ask for as many results as there are vectors so the filter below
        // sees every candidate.
        let params = FaissSearchParams {
            k: self.size(),
            ..Default::default()
        };
        let candidates = self.search(&query.as_f32(), &params)?;
        // NOTE(review): `search` yields distances (negated dot product for
        // `FlatIP`), where smaller means closer. Confirm against the trait's
        // contract that `>=` is the intended comparison direction.
        Ok(candidates
            .into_iter()
            .filter(|(_, score)| *score >= threshold)
            .collect())
    }

    /// Always returns `None`; this index stores raw `Vec<f32>` rows and holds
    /// no `crate::Vector` values.
    fn get_vector(&self, _uri: &str) -> Option<&crate::Vector> {
        None
    }
}
/// Convenience constructors for commonly used [`FaissIndex`] configurations.
pub struct FaissFactory;

impl FaissFactory {
    /// Builds an index whose type is chosen from `expected_size`.
    ///
    /// * below 10 000 vectors: `FlatL2`
    /// * below 1 000 000 vectors: `IvfFlat`
    /// * otherwise: `IvfPq`
    ///
    /// The training sample size is a tenth of `expected_size`, kept within
    /// `1..=100_000`. The cluster count is `sqrt(expected_size)` and at least
    /// 1. Both lower bounds keep very small expected sizes from producing a
    /// configuration that validation rejects.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`FaissIndex::new`].
    pub fn create_optimized_index(
        dimension: usize,
        expected_size: usize,
        use_gpu: bool,
    ) -> Result<FaissIndex> {
        let index_type = if expected_size < 10_000 {
            FaissIndexType::FlatL2
        } else if expected_size < 1_000_000 {
            FaissIndexType::IvfFlat
        } else {
            FaissIndexType::IvfPq
        };

        let training_sample_size = (expected_size / 10).clamp(1, 100_000);
        let clusters = std::cmp::max((expected_size as f32).sqrt() as usize, 1);

        let config = FaissConfig {
            index_type,
            dimension,
            training_sample_size,
            num_clusters: Some(clusters),
            use_gpu,
            ..Default::default()
        };
        FaissIndex::new(config)
    }

    /// Builds a GPU-enabled index with automatic index-type selection.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`FaissIndex::new`].
    pub fn create_gpu_index(dimension: usize, gpu_devices: Vec<u32>) -> Result<FaissIndex> {
        let config = FaissConfig {
            dimension,
            use_gpu: true,
            gpu_devices,
            index_type: FaissIndexType::Auto,
            ..Default::default()
        };
        FaissIndex::new(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// A freshly created index reports its dimension and holds no vectors.
    #[test]
    fn test_faiss_index_creation() -> Result<()> {
        let config = FaissConfig {
            dimension: 128,
            index_type: FaissIndexType::FlatL2,
            ..Default::default()
        };
        let index = FaissIndex::new(config)?;
        assert_eq!(index.dimension(), 128);
        assert_eq!(index.size(), 0);
        Ok(())
    }

    /// The nearest stored vector is ranked first and `k` limits the result count.
    #[test]
    fn test_faiss_add_and_search() -> Result<()> {
        let config = FaissConfig {
            dimension: 4,
            index_type: FaissIndexType::FlatL2,
            ..Default::default()
        };
        let index = FaissIndex::new(config)?;
        let vectors = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        let ids = vec!["vec1".to_string(), "vec2".to_string(), "vec3".to_string()];
        index.add_vectors(vectors, ids)?;
        assert_eq!(index.size(), 3);
        let query = vec![1.0, 0.1, 0.0, 0.0];
        let params = FaissSearchParams {
            k: 2,
            ..Default::default()
        };
        let results = index.search(&query, &params)?;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "vec1");
        Ok(())
    }

    /// Training an IVF index marks the training state as complete.
    #[test]
    fn test_faiss_training() -> Result<()> {
        let config = FaissConfig {
            dimension: 4,
            index_type: FaissIndexType::IvfFlat,
            num_clusters: Some(2),
            training_sample_size: 10,
            ..Default::default()
        };
        let index = FaissIndex::new(config)?;
        let training_vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| vec![i as f32, (i % 2) as f32, 0.0, 0.0])
            .collect();
        index.train(&training_vectors)?;
        let state = index
            .training_state
            .read()
            .expect("training_state lock not poisoned");
        assert!(state.is_trained);
        assert_eq!(state.training_progress, 1.0);
        Ok(())
    }

    /// Both factory constructors honour the requested dimension and GPU flag.
    #[test]
    fn test_faiss_factory() -> Result<()> {
        let index = FaissFactory::create_optimized_index(64, 1000, false)?;
        assert_eq!(index.dimension(), 64);
        let gpu_index = FaissFactory::create_gpu_index(128, vec![0])?;
        assert_eq!(gpu_index.dimension(), 128);
        assert!(gpu_index.config.use_gpu);
        Ok(())
    }

    /// With no stored vectors, automatic selection falls back to a flat index.
    #[test]
    fn test_faiss_auto_index_selection() -> Result<()> {
        let config = FaissConfig {
            dimension: 64,
            index_type: FaissIndexType::Auto,
            ..Default::default()
        };
        let index = FaissIndex::new(config)?;
        let index_str = index.faiss_index_string()?;
        assert_eq!(index_str, "Flat");
        Ok(())
    }
}