use faiss::{index::IndexImpl, index_factory, Idx, Index, MetricType};
use crate::distance::DistanceMetric;
use crate::error::{Result, SynaError};
use std::collections::HashMap;
/// Construction-time settings for a FAISS-backed vector index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FaissConfig {
    /// FAISS `index_factory` description string, e.g. `"IVF1024,Flat"`.
    pub index_type: String,
    /// Number of vectors to buffer before training an index type that needs training.
    pub train_size: usize,
    /// Intended IVF `nprobe` value (inverted lists probed per query).
    /// NOTE(review): not read by the index code visible alongside this type — confirm it is
    /// applied somewhere.
    pub nprobe: usize,
    /// Whether a GPU index is requested.
    /// NOTE(review): not read by the index code visible alongside this type — confirm it is
    /// applied somewhere.
    pub use_gpu: bool,
}

impl Default for FaissConfig {
    /// `"IVF1024,Flat"`, training after 10 000 vectors, `nprobe` 10, CPU only.
    fn default() -> Self {
        let index_type = String::from("IVF1024,Flat");
        Self {
            index_type,
            train_size: 10_000,
            nprobe: 10,
            use_gpu: false,
        }
    }
}
pub struct FaissIndex {
index: IndexImpl,
config: FaissConfig,
dimensions: u16,
key_to_id: HashMap<String, i64>,
id_to_key: HashMap<i64, String>,
next_id: i64,
is_trained: bool,
training_vectors: Vec<f32>,
}
impl FaissIndex {
pub fn new(dimensions: u16, metric: DistanceMetric, config: FaissConfig) -> Result<Self> {
let metric_type = match metric {
DistanceMetric::Euclidean => MetricType::L2,
DistanceMetric::Cosine | DistanceMetric::DotProduct => MetricType::InnerProduct,
};
let index = index_factory(dimensions as u32, &config.index_type, metric_type)
.map_err(|e| SynaError::IndexError(format!("FAISS index creation failed: {}", e)))?;
Ok(Self {
index,
config,
dimensions,
key_to_id: HashMap::new(),
id_to_key: HashMap::new(),
next_id: 0,
is_trained: false,
training_vectors: Vec::new(),
})
}
pub fn insert(&mut self, key: &str, vector: &[f32]) -> Result<()> {
if vector.len() != self.dimensions as usize {
return Err(SynaError::DimensionMismatch {
expected: self.dimensions,
got: vector.len() as u16,
});
}
if !self.is_trained && self.requires_training() {
self.training_vectors.extend_from_slice(vector);
if self.training_vectors.len() / self.dimensions as usize >= self.config.train_size {
self.train()?;
}
}
let id = self.next_id;
let idx = Idx::new(id as u64);
self.index
.add_with_ids(vector, &[idx])
.map_err(|e| SynaError::IndexError(format!("FAISS insert failed: {}", e)))?;
self.key_to_id.insert(key.to_string(), id);
self.id_to_key.insert(id, key.to_string());
self.next_id += 1;
Ok(())
}
pub fn search(&mut self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
let search_result = self
.index
.search(query, k)
.map_err(|e| SynaError::IndexError(format!("FAISS search failed: {}", e)))?;
let mut results = Vec::with_capacity(k);
for (dist, idx) in search_result
.distances
.iter()
.zip(search_result.labels.iter())
{
if let Some(id_val) = idx.get() {
let id = id_val as i64;
if let Some(key) = self.id_to_key.get(&id) {
results.push((key.clone(), *dist));
}
}
}
Ok(results)
}
fn train(&mut self) -> Result<()> {
self.index
.train(&self.training_vectors)
.map_err(|e| SynaError::IndexError(format!("FAISS training failed: {}", e)))?;
self.is_trained = true;
self.training_vectors.clear();
Ok(())
}
fn requires_training(&self) -> bool {
self.config.index_type.contains("IVF") || self.config.index_type.contains("PQ")
}
pub fn len(&self) -> usize {
self.key_to_id.len()
}
pub fn is_empty(&self) -> bool {
self.key_to_id.is_empty()
}
pub fn dimensions(&self) -> u16 {
self.dimensions
}
}