use common::DistanceMetric;
use parking_lot::RwLock;
use rand::Rng;
use std::collections::HashMap;
use crate::distance::calculate_distance;
/// Tunable parameters for an IVF (inverted-file) vector index.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IvfConfig {
// Number of k-means clusters (= inverted lists) built during training.
pub n_clusters: usize,
// How many of the nearest clusters to scan per query; higher = better
// recall, slower search.
pub n_probe: usize,
// Upper bound on k-means iterations during training.
pub max_iterations: usize,
// Training stops early once the largest centroid shift (Euclidean)
// falls below this threshold.
pub convergence_threshold: f32,
// Distance/similarity metric used for assignment and search.
pub metric: DistanceMetric,
}
impl Default for IvfConfig {
/// Defaults: 256 clusters, probe the 10 nearest lists per query, at most
/// 100 k-means iterations with a 1e-4 convergence threshold, cosine metric.
fn default() -> Self {
Self {
n_clusters: 256,
n_probe: 10,
max_iterations: 100,
convergence_threshold: 1e-4,
metric: DistanceMetric::Cosine,
}
}
}
/// A stored vector together with its caller-supplied identifier.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IndexedVector {
// External id returned in search results.
pub id: String,
// The raw vector components.
pub values: Vec<f32>,
}
/// An IVF index: vectors are bucketed into inverted lists keyed by their
/// nearest k-means centroid; queries scan only the `n_probe` closest lists.
/// Interior state is guarded by `parking_lot::RwLock`s so `add`/`search`
/// take `&self`.
pub struct IvfIndex {
config: IvfConfig,
// Set once by `train` (or restored from a snapshot); `None` until trained.
dimension: Option<usize>,
// Cluster centroids produced by k-means; index = cluster id.
centroids: RwLock<Vec<Vec<f32>>>,
// cluster id -> vectors assigned to that cluster.
inverted_lists: RwLock<HashMap<usize, Vec<IndexedVector>>>,
// Running total of stored vectors across all lists.
vector_count: RwLock<usize>,
// Guards `add`/`search` against use before `train`.
is_trained: RwLock<bool>,
}
impl IvfIndex {
/// Creates an empty, untrained index with the given configuration.
/// Call [`IvfIndex::train`] before adding or searching vectors.
pub fn new(config: IvfConfig) -> Self {
    Self {
        dimension: None,
        centroids: RwLock::new(Vec::new()),
        inverted_lists: RwLock::new(HashMap::new()),
        vector_count: RwLock::new(0),
        is_trained: RwLock::new(false),
        config,
    }
}
/// Convenience constructor using [`IvfConfig::default`].
pub fn with_defaults() -> Self {
    Self::new(Default::default())
}
/// Trains the index: validates the training set, runs k-means to produce
/// `min(n_clusters, |vectors|)` centroids, and resets all inverted lists.
///
/// # Errors
/// Returns an error for an empty training set, a zero dimension, or a
/// vector whose length differs from the first vector's.
pub fn train(&mut self, vectors: &[Vec<f32>]) -> Result<(), String> {
    let first = vectors
        .first()
        .ok_or_else(|| "Cannot train on empty vector set".to_string())?;
    let dim = first.len();
    if dim == 0 {
        return Err("Vector dimension cannot be zero".to_string());
    }
    if let Some(bad) = vectors.iter().find(|v| v.len() != dim) {
        return Err(format!(
            "Dimension mismatch: expected {}, got {}",
            dim,
            bad.len()
        ));
    }
    self.dimension = Some(dim);
    // Never ask for more clusters than we have training points.
    let n_clusters = self.config.n_clusters.min(vectors.len());
    *self.centroids.write() = self.kmeans(vectors, n_clusters)?;
    *self.is_trained.write() = true;
    {
        // Re-training discards any previously stored vectors.
        let mut lists = self.inverted_lists.write();
        lists.clear();
        lists.extend((0..n_clusters).map(|i| (i, Vec::new())));
    }
    tracing::info!(
        n_clusters = n_clusters,
        dimension = dim,
        training_vectors = vectors.len(),
        "IVF index trained"
    );
    Ok(())
}
/// Runs Lloyd's k-means over `vectors`, seeded with k-means++.
///
/// Iterates assignment/update steps until the largest centroid movement
/// (Euclidean) drops below `convergence_threshold`, or `max_iterations`
/// is reached. Assignment uses the configured metric (via
/// `find_nearest_centroid`); the convergence shift is always Euclidean.
fn kmeans(&self, vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>, String> {
    let dim = vectors[0].len();
    let mut rng = rand::thread_rng();
    let mut centroids = self.kmeans_plus_plus_init(vectors, k, &mut rng);
    for iteration in 0..self.config.max_iterations {
        // Assignment step: bucket vector indices by nearest centroid.
        let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
        for (idx, vector) in vectors.iter().enumerate() {
            // FIX: original read `¢roids` — a mojibake-mangled `&centroids`
            // (`&cent;` collapsed to `¢`) that did not compile.
            let nearest = self.find_nearest_centroid(vector, &centroids);
            assignments[nearest].push(idx);
        }
        // Update step: each centroid becomes the mean of its members.
        let mut new_centroids = Vec::with_capacity(k);
        let mut max_shift = 0.0f32;
        for (cluster_idx, indices) in assignments.iter().enumerate() {
            if indices.is_empty() {
                // Keep an empty cluster's centroid where it was.
                new_centroids.push(centroids[cluster_idx].clone());
                continue;
            }
            let mut new_centroid = vec![0.0f32; dim];
            for &idx in indices {
                for (j, val) in vectors[idx].iter().enumerate() {
                    new_centroid[j] += val;
                }
            }
            for val in &mut new_centroid {
                *val /= indices.len() as f32;
            }
            // FIX: same mojibake here (`¢roids[cluster_idx]` -> `&centroids[...]`).
            let shift = euclidean_distance(&centroids[cluster_idx], &new_centroid);
            max_shift = max_shift.max(shift);
            new_centroids.push(new_centroid);
        }
        centroids = new_centroids;
        if max_shift < self.config.convergence_threshold {
            tracing::debug!(
                iteration = iteration,
                max_shift = max_shift,
                "K-means converged"
            );
            break;
        }
    }
    Ok(centroids)
}
/// Picks `k` seed centroids using k-means++ (Arthur & Vassilvitskii):
/// the first centroid is chosen uniformly at random; each subsequent one
/// is sampled with probability proportional to its squared Euclidean
/// distance from the nearest centroid already chosen.
///
/// Stops early (returning fewer than `k` centroids) if every vector
/// coincides with an existing centroid (total distance mass is zero).
fn kmeans_plus_plus_init<R: Rng>(
    &self,
    vectors: &[Vec<f32>],
    k: usize,
    rng: &mut R,
) -> Vec<Vec<f32>> {
    let mut centroids = Vec::with_capacity(k);
    let first_idx = rng.gen_range(0..vectors.len());
    centroids.push(vectors[first_idx].clone());
    for _ in 1..k {
        // Squared distance from each vector to its nearest chosen centroid.
        let mut distances: Vec<f32> = vectors
            .iter()
            .map(|v| {
                centroids
                    .iter()
                    .map(|c| euclidean_distance(v, c))
                    .fold(f32::MAX, f32::min)
                    .powi(2)
            })
            .collect();
        let total: f32 = distances.iter().sum();
        if total == 0.0 {
            break;
        }
        // Normalize into a probability distribution.
        for d in &mut distances {
            *d /= total;
        }
        // Inverse-CDF sampling.
        // FIX: the fallback used to be index 0, so when floating-point
        // rounding left the cumulative sum just below `sample` the first
        // vector was silently re-picked. Defaulting to the last index is
        // the correct inverse-CDF fallback.
        let sample: f32 = rng.gen();
        let mut cumsum = 0.0;
        let mut selected = distances.len() - 1;
        for (i, &d) in distances.iter().enumerate() {
            cumsum += d;
            if cumsum >= sample {
                selected = i;
                break;
            }
        }
        centroids.push(vectors[selected].clone());
    }
    centroids
}
/// Returns the index of the best-scoring centroid for `vector` under the
/// configured metric. Higher score wins (assumes `calculate_distance`
/// returns a similarity-style score — consistent with `search`, which
/// sorts descending); ties keep the earliest centroid (strict `>`).
fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
    centroids
        .iter()
        .enumerate()
        .fold((0usize, f32::NEG_INFINITY), |best, (idx, centroid)| {
            let score = calculate_distance(vector, centroid, self.config.metric);
            if score > best.1 {
                (idx, score)
            } else {
                best
            }
        })
        .0
}
/// Ranks all centroids by score (descending) for `vector` and returns the
/// indices of the top `n` — the clusters a query will probe.
fn find_nearest_centroids(&self, vector: &[f32], n: usize) -> Vec<usize> {
    let guard = self.centroids.read();
    let mut ranked: Vec<(usize, f32)> = Vec::with_capacity(guard.len());
    for (idx, centroid) in guard.iter().enumerate() {
        ranked.push((idx, calculate_distance(vector, centroid, self.config.metric)));
    }
    // Descending by score; NaN-safe via partial_cmp fallback.
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked.iter().take(n).map(|&(idx, _)| idx).collect()
}
pub fn add(&self, id: String, vector: Vec<f32>) -> Result<(), String> {
if !*self.is_trained.read() {
return Err("Index must be trained before adding vectors".to_string());
}
if let Some(dim) = self.dimension {
if vector.len() != dim {
return Err(format!(
"Dimension mismatch: expected {}, got {}",
dim,
vector.len()
));
}
}
let centroids = self.centroids.read();
let cluster_idx = self.find_nearest_centroid(&vector, ¢roids);
drop(centroids);
let indexed = IndexedVector { id, values: vector };
let mut lists = self.inverted_lists.write();
lists.entry(cluster_idx).or_default().push(indexed);
drop(lists);
*self.vector_count.write() += 1;
Ok(())
}
/// Adds vectors one by one, returning how many were inserted.
/// Stops at (and propagates) the first error; earlier insertions remain.
pub fn add_batch(&self, vectors: Vec<(String, Vec<f32>)>) -> Result<usize, String> {
    vectors
        .into_iter()
        .try_fold(0usize, |inserted, (id, vector)| {
            self.add(id, vector)?;
            Ok(inserted + 1)
        })
}
/// Approximate k-nearest-neighbor search: scores `query` against every
/// vector in the `n_probe` closest inverted lists and returns the top `k`
/// results sorted by descending score.
///
/// # Errors
/// Returns an error if the index is untrained or `query` has the wrong
/// dimension.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>, String> {
    if !*self.is_trained.read() {
        return Err("Index must be trained before searching".to_string());
    }
    match self.dimension {
        Some(dim) if query.len() != dim => {
            return Err(format!(
                "Dimension mismatch: expected {}, got {}",
                dim,
                query.len()
            ));
        }
        _ => {}
    }
    let n_probe = self.config.n_probe.min(self.centroids.read().len());
    let probe_clusters = self.find_nearest_centroids(query, n_probe);
    let lists = self.inverted_lists.read();
    // Score every vector in each probed cluster.
    let mut candidates: Vec<SearchResult> = probe_clusters
        .iter()
        .filter_map(|cluster_idx| lists.get(cluster_idx))
        .flat_map(|vectors| {
            vectors.iter().map(|indexed| SearchResult {
                id: indexed.id.clone(),
                score: calculate_distance(query, &indexed.values, self.config.metric),
            })
        })
        .collect();
    // Descending by score; stable sort keeps insertion order on ties.
    candidates.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    candidates.truncate(k);
    Ok(candidates)
}
/// Removes the first stored vector whose id matches, preserving the
/// order of the remaining entries in its list. Returns whether anything
/// was removed. (Scans lists linearly; at most one entry is removed.)
pub fn remove(&self, id: &str) -> bool {
    let mut lists = self.inverted_lists.write();
    // `any` short-circuits after the first successful removal.
    let removed = lists.values_mut().any(|vectors| {
        match vectors.iter().position(|v| v.id == id) {
            Some(pos) => {
                vectors.remove(pos);
                true
            }
            None => false,
        }
    });
    drop(lists);
    if removed {
        *self.vector_count.write() -= 1;
    }
    removed
}
/// Number of vectors currently stored across all inverted lists.
pub fn len(&self) -> usize {
*self.vector_count.read()
}
/// True when no vectors have been added (or all were removed).
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Whether `train` has completed (or a trained snapshot was restored).
pub fn is_trained(&self) -> bool {
*self.is_trained.read()
}
/// Number of clusters actually built (may be less than the configured
/// `n_clusters` when the training set was smaller).
pub fn n_clusters(&self) -> usize {
self.centroids.read().len()
}
/// The configuration this index was built with.
pub fn config(&self) -> &IvfConfig {
&self.config
}
/// Trained vector dimension, or `None` before training.
pub fn dimension(&self) -> Option<usize> {
self.dimension
}
/// Snapshot copy of the centroids (for persistence).
pub(crate) fn centroids_read(&self) -> Vec<Vec<f32>> {
self.centroids.read().clone()
}
/// Snapshot copy of the inverted lists (for persistence).
pub(crate) fn inverted_lists_read(&self) -> HashMap<usize, Vec<IndexedVector>> {
self.inverted_lists.read().clone()
}
pub fn from_snapshot(snapshot: crate::persistence::IvfFullSnapshot) -> Result<Self, String> {
let mut inverted_lists = HashMap::new();
for (cluster_id, vectors) in snapshot.inverted_lists {
inverted_lists.insert(cluster_id, vectors);
}
Ok(Self {
config: snapshot.config,
dimension: snapshot.dimension,
centroids: RwLock::new(snapshot.centroids),
inverted_lists: RwLock::new(inverted_lists),
vector_count: RwLock::new(snapshot.vector_count),
is_trained: RwLock::new(snapshot.is_trained),
})
}
}
/// One search hit: the stored vector's id and its score under the
/// configured metric (higher = better match; results are sorted descending).
#[derive(Debug, Clone)]
pub struct SearchResult {
pub id: String,
pub score: f32,
}
/// Euclidean (L2) distance between two equal-length slices.
/// If lengths differ, extra trailing components are ignored (zip stops
/// at the shorter slice).
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut sum_sq = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
#[cfg(test)]
mod tests {
    //! Unit, recall, and clustering-quality tests for the IVF index.
    //!
    //! FIX: `test_kmeans_convergence` contained mojibake `¢roids[..]`
    //! (a mangled `&centroids[..]`) that did not compile; restored below.
    use super::*;

    /// Builds `n` random vectors of length `dim`, components in [0, 1).
    fn generate_random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rng = rand::thread_rng();
        (0..n)
            .map(|_| (0..dim).map(|_| rng.gen::<f32>()).collect())
            .collect()
    }

    #[test]
    fn test_ivf_train() {
        let vectors = generate_random_vectors(100, 32);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 10,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.n_clusters(), 10);
    }

    #[test]
    fn test_ivf_add_and_search() {
        let training_vectors = generate_random_vectors(100, 32);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 10,
            n_probe: 3,
            ..Default::default()
        });
        index.train(&training_vectors).unwrap();
        for (i, v) in training_vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }
        assert_eq!(index.len(), 100);
        // Querying with an indexed vector must return itself first.
        let query = &training_vectors[0];
        let results = index.search(query, 5).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        assert_eq!(results[0].id, "vec_0");
    }

    #[test]
    fn test_ivf_remove() {
        let vectors = generate_random_vectors(50, 16);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 5,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }
        assert_eq!(index.len(), 50);
        let removed = index.remove("vec_10");
        assert!(removed);
        assert_eq!(index.len(), 49);
        let not_removed = index.remove("nonexistent");
        assert!(!not_removed);
    }

    #[test]
    fn test_ivf_dimension_mismatch() {
        let vectors = generate_random_vectors(50, 16);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 5,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        index.add("test".to_string(), vectors[0].clone()).unwrap();
        // Trained on dim 16; a dim-32 vector must be rejected.
        let wrong_dim = vec![0.0; 32];
        let result = index.add("wrong".to_string(), wrong_dim);
        assert!(result.is_err());
    }

    #[test]
    fn test_ivf_untrained_error() {
        let index = IvfIndex::with_defaults();
        let result = index.add("test".to_string(), vec![0.0; 32]);
        assert!(result.is_err());
        let result = index.search(&[0.0; 32], 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_kmeans_convergence() {
        // Two well-separated 2-D blobs around (1, 0) and (0, 1).
        let mut vectors = Vec::new();
        let mut rng = rand::thread_rng();
        for _ in 0..30 {
            vectors.push(vec![1.0 + rng.gen::<f32>() * 0.1, rng.gen::<f32>() * 0.1]);
        }
        for _ in 0..30 {
            vectors.push(vec![rng.gen::<f32>() * 0.1, 1.0 + rng.gen::<f32>() * 0.1]);
        }
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 2,
            max_iterations: 50,
            convergence_threshold: 1e-4,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        let centroids = index.centroids.read();
        assert_eq!(centroids.len(), 2);
        // FIX: was garbled `¢roids[0]` / `¢roids[1]`; restored the borrows.
        let c1 = &centroids[0];
        let c2 = &centroids[1];
        let dist = euclidean_distance(c1, c2);
        assert!(
            dist > 0.5,
            "Centroids should be well separated, got dist={}",
            dist
        );
    }

    /// Exact top-k by brute force, used as ground truth for recall tests.
    fn brute_force_knn(
        query: &[f32],
        vectors: &[(String, Vec<f32>)],
        k: usize,
        metric: DistanceMetric,
    ) -> Vec<String> {
        let mut distances: Vec<(String, f32)> = vectors
            .iter()
            .map(|(id, v)| (id.clone(), calculate_distance(query, v, metric)))
            .collect();
        distances.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        distances.into_iter().take(k).map(|(id, _)| id).collect()
    }

    /// Fraction of `actual` ids that appear in `predicted`.
    fn calculate_recall(predicted: &[String], actual: &[String]) -> f32 {
        let predicted_set: std::collections::HashSet<_> = predicted.iter().collect();
        let found = actual
            .iter()
            .filter(|id| predicted_set.contains(id))
            .count();
        found as f32 / actual.len() as f32
    }

    #[test]
    fn test_ivf_recall_at_k() {
        let n_vectors = 500;
        let dim = 64;
        let n_clusters = 20;
        let k = 10;
        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 5,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        let indexed: Vec<(String, Vec<f32>)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();
        for (id, v) in &indexed {
            index.add(id.clone(), v.clone()).unwrap();
        }
        let n_queries = 20;
        let mut total_recall = 0.0;
        for q_idx in 0..n_queries {
            let query = &vectors[q_idx * (n_vectors / n_queries)];
            let ivf_results = index.search(query, k).unwrap();
            let ivf_ids: Vec<String> = ivf_results.iter().map(|r| r.id.clone()).collect();
            let exact_ids = brute_force_knn(query, &indexed, k, DistanceMetric::Euclidean);
            let recall = calculate_recall(&ivf_ids, &exact_ids);
            total_recall += recall;
        }
        let avg_recall = total_recall / n_queries as f32;
        assert!(
            avg_recall > 0.5,
            "Average recall@{} should be > 0.5, got {}",
            k,
            avg_recall
        );
    }

    #[test]
    fn test_ivf_nprobe_effect_on_recall() {
        let n_vectors = 300;
        let dim = 32;
        let n_clusters = 15;
        let k = 5;
        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index_low = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 2,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index_low.train(&vectors).unwrap();
        let indexed: Vec<(String, Vec<f32>)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();
        for (id, v) in &indexed {
            index_low.add(id.clone(), v.clone()).unwrap();
        }
        let mut index_high = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 10,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index_high.train(&vectors).unwrap();
        for (id, v) in &indexed {
            index_high.add(id.clone(), v.clone()).unwrap();
        }
        let n_queries = 10;
        let mut recall_low = 0.0;
        let mut recall_high = 0.0;
        for q_idx in 0..n_queries {
            let query = &vectors[q_idx * (n_vectors / n_queries)];
            let low_results = index_low.search(query, k).unwrap();
            let low_ids: Vec<String> = low_results.iter().map(|r| r.id.clone()).collect();
            let high_results = index_high.search(query, k).unwrap();
            let high_ids: Vec<String> = high_results.iter().map(|r| r.id.clone()).collect();
            let exact_ids = brute_force_knn(query, &indexed, k, DistanceMetric::Euclidean);
            recall_low += calculate_recall(&low_ids, &exact_ids);
            recall_high += calculate_recall(&high_ids, &exact_ids);
        }
        let avg_recall_low = recall_low / n_queries as f32;
        let avg_recall_high = recall_high / n_queries as f32;
        // NOTE: the two indexes are trained independently (different random
        // seeds), so this is a statistical expectation rather than a strict
        // guarantee; with n_probe 10 of 15 clusters it holds in practice.
        assert!(
            avg_recall_high >= avg_recall_low,
            "Higher nprobe should give equal or better recall: low={}, high={}",
            avg_recall_low,
            avg_recall_high
        );
    }

    #[test]
    fn test_ivf_cluster_distribution() {
        let n_vectors = 200;
        let dim = 16;
        let n_clusters = 10;
        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 3,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }
        let lists = index.inverted_lists.read();
        let cluster_sizes: Vec<usize> = lists.values().map(|v| v.len()).collect();
        let non_empty_clusters = cluster_sizes.iter().filter(|&&s| s > 0).count();
        assert!(
            non_empty_clusters >= n_clusters / 2,
            "At least half of clusters should be used: {} out of {}",
            non_empty_clusters,
            n_clusters
        );
        let max_cluster_size = cluster_sizes.iter().max().copied().unwrap_or(0);
        assert!(
            max_cluster_size < n_vectors * 3 / 4,
            "No cluster should have more than 75% of vectors: {} out of {}",
            max_cluster_size,
            n_vectors
        );
    }

    #[test]
    fn test_ivf_high_dimensional_accuracy() {
        let n_vectors = 200;
        let dim = 128;
        let n_clusters = 16;
        let k = 5;
        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters,
            n_probe: 4,
            metric: DistanceMetric::Cosine,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        let indexed: Vec<(String, Vec<f32>)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();
        for (id, v) in &indexed {
            index.add(id.clone(), v.clone()).unwrap();
        }
        let query = &vectors[0];
        let results = index.search(query, k).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= k);
        assert_eq!(results[0].id, "vec_0");
        for result in &results {
            assert!(
                result.score.is_finite(),
                "Score should be finite, got {}",
                result.score
            );
        }
    }

    #[test]
    fn test_ivf_cosine_vs_euclidean() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![0.5, 0.5, 0.0],
        ];
        let mut index_cosine = IvfIndex::new(IvfConfig {
            n_clusters: 2,
            n_probe: 2,
            metric: DistanceMetric::Cosine,
            ..Default::default()
        });
        index_cosine.train(&vectors).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index_cosine.add(format!("vec_{}", i), v.clone()).unwrap();
        }
        let query = vec![0.95, 0.05, 0.0];
        let results_cosine = index_cosine.search(&query, 3).unwrap();
        assert_eq!(results_cosine.len(), 3);
        let mut index_euclidean = IvfIndex::new(IvfConfig {
            n_clusters: 2,
            n_probe: 2,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index_euclidean.train(&vectors).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index_euclidean
                .add(format!("vec_{}", i), v.clone())
                .unwrap();
        }
        let results_euclidean = index_euclidean.search(&query, 3).unwrap();
        assert_eq!(results_euclidean.len(), 3);
        // Both metrics should rank one of the two near-(1,0,0) vectors first.
        let top_cosine = &results_cosine[0].id;
        let top_euclidean = &results_euclidean[0].id;
        assert!(
            top_cosine == "vec_0" || top_cosine == "vec_1",
            "Cosine top result should be vec_0 or vec_1, got {}",
            top_cosine
        );
        assert!(
            top_euclidean == "vec_0" || top_euclidean == "vec_1",
            "Euclidean top result should be vec_0 or vec_1, got {}",
            top_euclidean
        );
    }

    #[test]
    fn test_ivf_batch_accuracy() {
        let n_vectors = 100;
        let dim = 32;
        let vectors = generate_random_vectors(n_vectors, dim);
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 10,
            n_probe: 5,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        let batch: Vec<(String, Vec<f32>)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), v.clone()))
            .collect();
        let added = index.add_batch(batch.clone()).unwrap();
        assert_eq!(added, n_vectors);
        assert_eq!(index.len(), n_vectors);
        let query = &vectors[0];
        let results = index.search(query, 5).unwrap();
        assert!(!results.is_empty());
        assert!(
            results.iter().any(|r| r.id == "vec_0"),
            "Query vector should be in search results"
        );
    }

    #[test]
    fn test_ivf_empty_cluster_handling() {
        // 3 vectors, 3 clusters: k-means may leave a cluster empty; search
        // must still work with n_probe covering all clusters.
        let vectors = vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0]];
        let mut index = IvfIndex::new(IvfConfig {
            n_clusters: 3,
            n_probe: 3,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        });
        index.train(&vectors).unwrap();
        for (i, v) in vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), v.clone()).unwrap();
        }
        let results = index.search(&vec![0.5, 0.5], 2).unwrap();
        assert!(!results.is_empty());
    }
}