use common::{DistanceMetric, Vector};
use parking_lot::RwLock;
use rand::seq::SliceRandom;
use std::collections::HashMap;
use crate::pq::{PQConfig, ProductQuantizer};
/// Tunable parameters for an IVF-PQ (inverted file + product
/// quantization) index.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IvfPqConfig {
    /// Number of coarse (IVF) clusters; clamped to the training-set size.
    pub n_clusters: usize,
    /// Number of nearest coarse clusters scanned per search.
    pub n_probe: usize,
    /// Number of PQ subquantizers (one code byte stored per subquantizer).
    pub pq_subquantizers: usize,
    /// Number of centroids in each PQ subquantizer codebook.
    pub pq_centroids: usize,
    /// K-means iterations used to train the coarse quantizer.
    pub ivf_iterations: usize,
    /// K-means iterations used to train the PQ codebooks.
    pub pq_iterations: usize,
    /// Distance metric forwarded to the product quantizer.
    /// NOTE(review): coarse cluster assignment always uses Euclidean
    /// distance regardless of this setting — confirm that is intended.
    pub metric: DistanceMetric,
}
impl Default for IvfPqConfig {
    /// Defaults sized for mid-sized collections: 256 coarse clusters with
    /// 8 probed per query, and 8 PQ subquantizers of 256 centroids each.
    fn default() -> Self {
        Self {
            n_clusters: 256,
            n_probe: 8,
            pq_subquantizers: 8,
            pq_centroids: 256,
            ivf_iterations: 20,
            pq_iterations: 10,
            metric: DistanceMetric::Euclidean,
        }
    }
}
/// One hit returned by [`IvfPqIndex::search`].
#[derive(Debug, Clone)]
pub struct IvfPqSearchResult {
    /// Identifier of the stored vector.
    pub id: String,
    /// ADC score from the PQ distance table — an approximation of the
    /// distance between query and stored vector under the configured metric.
    pub score: f32,
    /// Index of the inverted list (coarse cluster) the entry came from.
    pub cluster_id: usize,
}
/// A stored vector inside an inverted list: its external id plus the PQ
/// codes of its residual (one byte per subquantizer).
#[derive(Debug, Clone)]
struct PqEntry {
    id: String,
    codes: Vec<u8>,
}
/// IVF-PQ approximate nearest-neighbor index: a coarse k-means quantizer
/// partitions vectors into inverted lists, and each vector's residual
/// (vector minus assigned centroid) is compressed with a product quantizer.
pub struct IvfPqIndex {
    /// Index parameters, fixed at construction.
    config: IvfPqConfig,
    /// Vector dimensionality; set during `train`.
    dimension: Option<usize>,
    /// Coarse centroids produced by k-means.
    centroids: Vec<Vec<f32>>,
    /// Product quantizer trained on residuals; `Some` after `train`.
    pq: Option<ProductQuantizer>,
    /// One lock-guarded posting list per coarse cluster.
    inverted_lists: Vec<RwLock<Vec<PqEntry>>>,
    /// True once `train` has completed successfully.
    trained: bool,
}
impl IvfPqIndex {
/// Creates an empty, untrained index with the given configuration.
/// Call [`IvfPqIndex::train`] before `add` or `search`.
pub fn new(config: IvfPqConfig) -> Self {
    Self {
        config,
        dimension: None,
        centroids: Vec::new(),
        pq: None,
        inverted_lists: Vec::new(),
        trained: false,
    }
}
/// Returns `true` once the index has been successfully trained.
pub fn is_trained(&self) -> bool {
    self.trained
}
/// Returns the trained vector dimensionality, or `None` if untrained.
pub fn dimension(&self) -> Option<usize> {
    self.dimension
}
/// Reports occupancy and an estimated memory footprint for the index.
///
/// Memory accounting covers coarse centroids, PQ codebooks, and stored
/// codes (one byte per subquantizer per vector); entry id strings are
/// not counted.
pub fn stats(&self) -> IvfPqStats {
    // Snapshot each posting list's length under its read lock.
    let list_sizes: Vec<usize> = self
        .inverted_lists
        .iter()
        .map(|list| list.read().len())
        .collect();
    let total_vectors: usize = list_sizes.iter().sum();
    let avg_list_size = if list_sizes.is_empty() {
        0.0
    } else {
        total_vectors as f64 / list_sizes.len() as f64
    };
    let dim = self.dimension.unwrap_or(0);
    // Coarse centroids: clusters * dim f32 values, 4 bytes each.
    let centroid_memory = self.centroids.len() * dim * 4;
    // PQ codebooks: subquantizers * centroids * subvector_dim f32 values.
    let pq_memory = self
        .pq
        .as_ref()
        .map(|pq| {
            pq.config.num_subquantizers
                * pq.config.num_centroids
                * (dim / pq.config.num_subquantizers)
                * 4
        })
        .unwrap_or(0);
    // Stored codes: one u8 per subquantizer per vector.
    let codes_memory = total_vectors * self.config.pq_subquantizers;
    IvfPqStats {
        n_clusters: self.centroids.len(),
        total_vectors,
        avg_list_size,
        max_list_size: list_sizes.iter().copied().max().unwrap_or(0),
        min_list_size: list_sizes.iter().copied().min().unwrap_or(0),
        trained: self.trained,
        dimension: self.dimension,
        memory_bytes: centroid_memory + pq_memory + codes_memory,
    }
}
/// Trains the index: fits the coarse k-means quantizer on `vectors`,
/// then fits the product quantizer on the residuals
/// (vector minus assigned coarse centroid).
///
/// # Errors
/// Returns an error on an empty training set, inconsistent vector
/// dimensions, or PQ construction/training failure.
pub fn train(&mut self, vectors: &[Vector]) -> Result<(), String> {
    if vectors.is_empty() {
        return Err("Cannot train on empty vector set".to_string());
    }
    let dim = vectors[0].values.len();
    // Validate all dimensions BEFORE mutating any index state; the
    // previous version set `self.dimension` first, leaving a stale
    // dimension behind when validation failed.
    for v in vectors {
        if v.values.len() != dim {
            return Err(format!(
                "Dimension mismatch: expected {}, got {}",
                dim,
                v.values.len()
            ));
        }
    }
    self.dimension = Some(dim);
    // Never ask for more clusters than there are training vectors.
    let n_clusters = self.config.n_clusters.min(vectors.len());
    self.centroids = self.kmeans_train(vectors, n_clusters)?;
    self.inverted_lists = (0..n_clusters).map(|_| RwLock::new(Vec::new())).collect();
    // PQ is trained on residuals so codebooks model the local geometry
    // around each coarse centroid rather than the full space.
    let residuals: Vec<Vector> = vectors
        .iter()
        .map(|v| {
            let (cluster_id, _) = self.find_nearest_centroid(&v.values);
            Vector {
                id: v.id.clone(),
                values: self.compute_residual(&v.values, cluster_id),
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            }
        })
        .collect();
    let pq_config = PQConfig {
        num_subquantizers: self.config.pq_subquantizers,
        num_centroids: self.config.pq_centroids,
        kmeans_iterations: self.config.pq_iterations,
        distance_metric: self.config.metric,
    };
    let mut pq = ProductQuantizer::new(pq_config, dim)?;
    pq.train(&residuals)?;
    self.pq = Some(pq);
    self.trained = true;
    Ok(())
}
/// Encodes and stores `vectors` into their nearest clusters' posting
/// lists, returning how many were stored. Vectors whose dimensionality
/// does not match the index are silently skipped (best-effort).
///
/// # Errors
/// Returns an error if the index is untrained or PQ encoding fails;
/// entries already stored by this call are kept.
pub fn add(&self, vectors: &[Vector]) -> Result<usize, String> {
    if !self.trained {
        return Err("Index must be trained before adding vectors".to_string());
    }
    let pq = self.pq.as_ref().ok_or("PQ not initialized")?;
    let dim = self.dimension.ok_or("Dimension not set")?;
    let mut added = 0usize;
    // Only vectors of the trained dimensionality are accepted.
    for vector in vectors.iter().filter(|v| v.values.len() == dim) {
        let (cluster_id, _) = self.find_nearest_centroid(&vector.values);
        let residual = self.compute_residual(&vector.values, cluster_id);
        let codes = pq.encode(&residual)?;
        self.inverted_lists[cluster_id].write().push(PqEntry {
            id: vector.id.clone(),
            codes,
        });
        added += 1;
    }
    Ok(added)
}
/// Searches for the `k` approximate nearest neighbors of `query`.
///
/// Probes the `n_probe` closest coarse clusters and ranks their entries
/// by asymmetric distance computation (ADC) against the query residual.
///
/// # Errors
/// Returns an error if the index is untrained or `query` does not match
/// the trained dimensionality.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<IvfPqSearchResult>, String> {
    if !self.trained {
        return Err("Index must be trained before searching".to_string());
    }
    let pq = self.pq.as_ref().ok_or("PQ not initialized")?;
    let dim = self.dimension.ok_or("Dimension not set")?;
    if query.len() != dim {
        return Err(format!(
            "Query dimension {} doesn't match index dimension {}",
            query.len(),
            dim
        ));
    }
    let n_probe = self.config.n_probe.min(self.centroids.len());
    let nearest_clusters = self.find_nearest_centroids(query, n_probe);
    let mut candidates: Vec<IvfPqSearchResult> = Vec::new();
    for (cluster_id, _) in nearest_clusters {
        // ADC: compare the query's residual (w.r.t. this cluster's
        // centroid) against stored residual codes via a lookup table.
        let query_residual = self.compute_residual(query, cluster_id);
        let distance_table = pq.compute_distance_table(&query_residual)?;
        let list = self.inverted_lists[cluster_id].read();
        for entry in list.iter() {
            let score = pq.compute_distance_adc(&distance_table, &entry.codes);
            candidates.push(IvfPqSearchResult {
                id: entry.id.clone(),
                score,
                cluster_id,
            });
        }
    }
    // BUG FIX: `score` is an ADC *distance* (smaller = closer). The
    // previous descending sort returned the k FARTHEST vectors, which
    // contradicts this file's own tests expecting the query vector to
    // rank among its own nearest neighbors. Sort ascending.
    candidates.sort_by(|a, b| {
        a.score
            .partial_cmp(&b.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    candidates.truncate(k);
    Ok(candidates)
}
/// Returns the index and Euclidean distance of the coarse centroid
/// closest to `vector`. With no centroids, returns `(0, f32::MAX)`.
/// Ties keep the lowest-indexed centroid.
fn find_nearest_centroid(&self, vector: &[f32]) -> (usize, f32) {
    self.centroids
        .iter()
        .enumerate()
        .map(|(idx, centroid)| (idx, euclidean_distance(vector, centroid)))
        .fold((0usize, f32::MAX), |best, candidate| {
            // Strict `<` keeps the earlier centroid on equal distances.
            if candidate.1 < best.1 {
                candidate
            } else {
                best
            }
        })
}
/// Returns up to `n` coarse centroids closest to `vector`, as
/// `(centroid_index, distance)` pairs sorted by ascending distance.
fn find_nearest_centroids(&self, vector: &[f32], n: usize) -> Vec<(usize, f32)> {
    let mut scored: Vec<(usize, f32)> = Vec::with_capacity(self.centroids.len());
    for (idx, centroid) in self.centroids.iter().enumerate() {
        scored.push((idx, euclidean_distance(vector, centroid)));
    }
    // NaN-safe comparison: incomparable pairs are treated as equal.
    scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(n);
    scored
}
/// Element-wise residual `vector - centroids[cluster_id]`. If the
/// lengths differ, only the overlapping prefix is produced (zip
/// truncates); callers validate dimensions beforehand.
fn compute_residual(&self, vector: &[f32], cluster_id: usize) -> Vec<f32> {
    let centroid = &self.centroids[cluster_id];
    let mut residual = Vec::with_capacity(vector.len().min(centroid.len()));
    for (v, c) in vector.iter().zip(centroid.iter()) {
        residual.push(v - c);
    }
    residual
}
/// Trains `k` coarse centroids over `vectors` with Lloyd's k-means,
/// initialized from a random sample of the input, running at most
/// `ivf_iterations` rounds or until centroid movement falls below 1e-4.
///
/// # Errors
/// Returns an error when `vectors` is empty or `k == 0`.
fn kmeans_train(&self, vectors: &[Vector], k: usize) -> Result<Vec<Vec<f32>>, String> {
    if vectors.is_empty() || k == 0 {
        return Err("Invalid input for k-means".to_string());
    }
    let dim = vectors[0].values.len();
    // Initialize centroids from a random sample of the training vectors.
    let mut rng = rand::thread_rng();
    let mut indices: Vec<usize> = (0..vectors.len()).collect();
    indices.shuffle(&mut rng);
    let mut centroids: Vec<Vec<f32>> = indices
        .iter()
        .take(k)
        .map(|&i| vectors[i].values.clone())
        .collect();
    // Defensive fallback: pad with zero vectors if there are fewer
    // samples than clusters (callers clamp k to vectors.len()).
    while centroids.len() < k {
        centroids.push(vec![0.0; dim]);
    }
    for _ in 0..self.config.ivf_iterations {
        // Assignment step. A Vec indexed by cluster id replaces the
        // original HashMap keyed by 0..k: no hashing overhead and a
        // deterministic update order.
        let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
        for (vec_idx, v) in vectors.iter().enumerate() {
            let mut best_cluster = 0;
            let mut best_dist = f32::MAX;
            for (cluster_id, centroid) in centroids.iter().enumerate() {
                let dist = euclidean_distance(&v.values, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_cluster = cluster_id;
                }
            }
            assignments[best_cluster].push(vec_idx);
        }
        // Update step: move each non-empty cluster to its members' mean.
        let mut converged = true;
        for (cluster_id, member_indices) in assignments.iter().enumerate() {
            if member_indices.is_empty() {
                // Empty clusters keep their previous centroid.
                continue;
            }
            let mut new_centroid = vec![0.0; dim];
            for &idx in member_indices {
                for (j, val) in vectors[idx].values.iter().enumerate() {
                    new_centroid[j] += val;
                }
            }
            for val in &mut new_centroid {
                *val /= member_indices.len() as f32;
            }
            // BUG FIX: the original read `¢roids[...]` — a mis-encoded
            // `&centroids[...]` that does not compile.
            let diff = euclidean_distance(&centroids[cluster_id], &new_centroid);
            if diff > 1e-4 {
                converged = false;
            }
            centroids[cluster_id] = new_centroid;
        }
        if converged {
            break;
        }
    }
    Ok(centroids)
}
}
/// Snapshot of index occupancy and estimated memory usage, as returned
/// by [`IvfPqIndex::stats`].
#[derive(Debug, Clone)]
pub struct IvfPqStats {
    /// Number of coarse clusters (inverted lists).
    pub n_clusters: usize,
    /// Total encoded vectors across all posting lists.
    pub total_vectors: usize,
    /// Mean posting-list length.
    pub avg_list_size: f64,
    /// Longest posting-list length.
    pub max_list_size: usize,
    /// Shortest posting-list length.
    pub min_list_size: usize,
    /// Whether the index has been trained.
    pub trained: bool,
    /// Trained vector dimensionality, if any.
    pub dimension: Option<usize>,
    /// Estimated footprint: centroids + PQ codebooks + stored codes.
    pub memory_bytes: usize,
}
/// L2 (Euclidean) distance between two slices. If the slices differ in
/// length, only the overlapping prefix is compared (zip truncates).
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut sum_sq = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = x - y;
        sum_sq += d * d;
    }
    sum_sq.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` pseudo-random vectors of dimension `dim` with values in
    /// [-1, 1). Seeded RNG keeps the tests deterministic.
    fn create_test_vectors(n: usize, dim: usize) -> Vec<Vector> {
        use rand::Rng;
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        (0..n)
            .map(|i| {
                let values: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
                Vector {
                    id: format!("v{}", i),
                    values,
                    metadata: None,
                    ttl_seconds: None,
                    expires_at: None,
                }
            })
            .collect()
    }

    // A fresh index reports untrained and no dimension.
    #[test]
    fn test_ivfpq_creation() {
        let config = IvfPqConfig::default();
        let index = IvfPqIndex::new(config);
        assert!(!index.is_trained());
        assert_eq!(index.dimension(), None);
    }

    // Training on a small random set succeeds and records the dimension.
    #[test]
    fn test_ivfpq_training() {
        let config = IvfPqConfig {
            n_clusters: 4,
            n_probe: 2,
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };
        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(100, 16);
        let result = index.train(&vectors);
        assert!(result.is_ok(), "Training failed: {:?}", result.err());
        assert!(index.is_trained());
        assert_eq!(index.dimension(), Some(16));
    }

    // With every cluster probed (n_probe == n_clusters), searching with a
    // stored vector's own values must surface that vector in the top 10.
    #[test]
    fn test_ivfpq_add_and_search() {
        let config = IvfPqConfig {
            n_clusters: 4,
            n_probe: 4,
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };
        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(100, 16);
        index.train(&vectors).unwrap();
        let added = index.add(&vectors).unwrap();
        assert_eq!(added, 100);
        let query = &vectors[0].values;
        let results = index.search(query, 10).unwrap();
        assert!(!results.is_empty(), "Results should not be empty");
        let found_self = results.iter().any(|r| r.id == "v0");
        assert!(
            found_self,
            "Query vector should be found in results. Got: {:?}",
            results.iter().map(|r| &r.id).collect::<Vec<_>>()
        );
    }

    // Stats must reflect cluster count, stored vectors, and a non-zero
    // memory estimate after train + add.
    #[test]
    fn test_ivfpq_stats() {
        let config = IvfPqConfig {
            n_clusters: 4,
            n_probe: 2,
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };
        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(100, 16);
        index.train(&vectors).unwrap();
        index.add(&vectors).unwrap();
        let stats = index.stats();
        assert_eq!(stats.n_clusters, 4);
        assert_eq!(stats.total_vectors, 100);
        assert!(stats.trained);
        assert_eq!(stats.dimension, Some(16));
        assert!(stats.memory_bytes > 0);
    }

    // Recall check: each sampled vector should appear in its own top-20
    // at least half the time (loose bound — PQ is lossy compression).
    #[test]
    fn test_ivfpq_search_quality() {
        let config = IvfPqConfig {
            n_clusters: 8,
            n_probe: 8,
            pq_subquantizers: 4,
            pq_centroids: 16,
            ivf_iterations: 10,
            pq_iterations: 10,
            metric: DistanceMetric::Euclidean,
        };
        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(200, 32);
        index.train(&vectors).unwrap();
        index.add(&vectors).unwrap();
        let mut total_recall = 0.0;
        let test_queries = 10;
        for i in 0..test_queries {
            let query = &vectors[i * 10].values;
            let results = index.search(query, 20).unwrap();
            let expected_id = format!("v{}", i * 10);
            if results.iter().any(|r| r.id == expected_id) {
                total_recall += 1.0;
            }
        }
        let recall = total_recall / test_queries as f32;
        assert!(
            recall >= 0.5,
            "Recall should be at least 50%, got {}%",
            recall * 100.0
        );
    }

    // Searching a trained index with no added vectors yields no results.
    #[test]
    fn test_ivfpq_empty_search() {
        let config = IvfPqConfig {
            n_clusters: 4,
            n_probe: 2,
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };
        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(50, 16);
        index.train(&vectors).unwrap();
        let query = &vectors[0].values;
        let results = index.search(query, 5).unwrap();
        assert!(results.is_empty());
    }

    // Searching before training must fail with a "trained" error message.
    #[test]
    fn test_ivfpq_untrained_error() {
        let index = IvfPqIndex::new(IvfPqConfig::default());
        let result = index.search(&[0.0; 128], 5);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("trained"));
    }
}