use anyhow::{Context, Result};
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::simd;
use crate::types::{DistanceMetric, SearchResult};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqConfig {
pub nclusters: usize,
pub nsubvectors: usize,
pub nbits: usize,
pub nprobe: usize,
pub metric: DistanceMetric,
pub max_kmeans_iterations: usize,
pub kmeans_tolerance: f32,
}
impl Default for IvfPqConfig {
fn default() -> Self {
Self {
nclusters: 256,
nsubvectors: 64,
nbits: 8,
nprobe: 16,
metric: DistanceMetric::Cosine,
max_kmeans_iterations: 100,
kmeans_tolerance: 1e-4,
}
}
}
impl IvfPqConfig {
pub fn with_nclusters(mut self, nclusters: usize) -> Self {
self.nclusters = nclusters;
self
}
pub fn with_nsubvectors(mut self, nsubvectors: usize) -> Self {
self.nsubvectors = nsubvectors;
self
}
pub fn with_nbits(mut self, nbits: usize) -> Self {
self.nbits = nbits;
self
}
pub fn with_nprobe(mut self, nprobe: usize) -> Self {
self.nprobe = nprobe;
self
}
pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
self.metric = metric;
self
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ProductQuantizer {
nsubvectors: usize,
subvector_dim: usize,
codebooks: Vec<Vec<Vec<f32>>>,
ncentroids: usize,
}
impl ProductQuantizer {
fn new(dim: usize, nsubvectors: usize, nbits: usize) -> Result<Self> {
if !dim.is_multiple_of(nsubvectors) {
anyhow::bail!(
"Vector dimension {} must be divisible by number of sub-vectors {}",
dim,
nsubvectors
);
}
let subvector_dim = dim / nsubvectors;
let ncentroids = 1 << nbits;
Ok(Self {
nsubvectors,
subvector_dim,
codebooks: vec![],
ncentroids,
})
}
fn train(&mut self, vectors: &[Vec<f32>], iterations: usize) -> Result<()> {
self.codebooks.clear();
for subvec_idx in 0..self.nsubvectors {
let start = subvec_idx * self.subvector_dim;
let end = start + self.subvector_dim;
let subvectors: Vec<Vec<f32>> =
vectors.iter().map(|v| v[start..end].to_vec()).collect();
let centroids = kmeans(&subvectors, self.ncentroids, iterations)?;
self.codebooks.push(centroids);
}
Ok(())
}
fn encode(&self, vector: &[f32]) -> Vec<u8> {
let mut codes = Vec::with_capacity(self.nsubvectors);
for subvec_idx in 0..self.nsubvectors {
let start = subvec_idx * self.subvector_dim;
let end = start + self.subvector_dim;
let subvector = &vector[start..end];
let mut best_idx = 0;
let mut best_dist = f32::MAX;
for (centroid_idx, centroid) in self.codebooks[subvec_idx].iter().enumerate() {
let dist = euclidean_distance(subvector, centroid);
if dist < best_dist {
best_dist = dist;
best_idx = centroid_idx;
}
}
codes.push(best_idx as u8);
}
codes
}
fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
let mut total_dist = 0.0;
#[allow(clippy::needless_range_loop)]
for subvec_idx in 0..self.nsubvectors {
let start = subvec_idx * self.subvector_dim;
let end = start + self.subvector_dim;
let query_subvector = &query[start..end];
let code = codes[subvec_idx] as usize;
let centroid = &self.codebooks[subvec_idx][code];
total_dist += euclidean_distance(query_subvector, centroid);
}
total_dist
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqIndex {
config: IvfPqConfig,
centroids: Vec<Vec<f32>>,
inverted_lists: Vec<Vec<(String, Vec<u8>)>>,
pq: Option<ProductQuantizer>,
dim: Option<usize>,
size: usize,
}
impl IvfPqIndex {
pub fn new(config: IvfPqConfig) -> Self {
Self {
config,
centroids: Vec::new(),
inverted_lists: Vec::new(),
pq: None,
dim: None,
size: 0,
}
}
pub fn build(&mut self, vectors: &HashMap<String, Vec<f32>>) -> Result<()> {
if vectors.is_empty() {
anyhow::bail!("Cannot build index with empty vector collection");
}
let dim = vectors.values().next().unwrap().len();
self.dim = Some(dim);
let vec_list: Vec<Vec<f32>> = vectors.values().cloned().collect();
println!(
"Training coarse quantizer ({} clusters)...",
self.config.nclusters
);
self.centroids = kmeans(
&vec_list,
self.config.nclusters,
self.config.max_kmeans_iterations,
)
.context("Failed to train coarse quantizer")?;
println!(
"Training product quantizer ({} sub-vectors)...",
self.config.nsubvectors
);
let mut pq = ProductQuantizer::new(dim, self.config.nsubvectors, self.config.nbits)?;
pq.train(&vec_list, 50)?; self.pq = Some(pq);
println!("Assigning vectors to clusters and quantizing...");
self.inverted_lists = vec![Vec::new(); self.config.nclusters];
for (entity_id, vector) in vectors {
let cluster_id = self.assign_to_cluster(vector);
let codes = self.pq.as_ref().unwrap().encode(vector);
self.inverted_lists[cluster_id].push((entity_id.clone(), codes));
}
self.size = vectors.len();
println!(
"Index built: {} vectors in {} clusters",
self.size, self.config.nclusters
);
Ok(())
}
fn assign_to_cluster(&self, vector: &[f32]) -> usize {
let mut best_idx = 0;
let mut best_dist = f32::MAX;
for (idx, centroid) in self.centroids.iter().enumerate() {
let dist = compute_distance(&self.config.metric, vector, centroid);
if dist < best_dist {
best_dist = dist;
best_idx = idx;
}
}
best_idx
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
if self.pq.is_none() {
anyhow::bail!("Index not built yet");
}
let mut cluster_distances: Vec<(usize, f32)> = self
.centroids
.iter()
.enumerate()
.map(|(idx, centroid)| (idx, compute_distance(&self.config.metric, query, centroid)))
.collect();
cluster_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let probe_clusters: Vec<usize> = cluster_distances
.iter()
.take(self.config.nprobe.min(self.centroids.len()))
.map(|(idx, _)| *idx)
.collect();
let pq = self.pq.as_ref().unwrap();
let mut candidates = Vec::new();
for cluster_id in probe_clusters {
for (entity_id, codes) in &self.inverted_lists[cluster_id] {
let dist = pq.asymmetric_distance(query, codes);
candidates.push(SearchResult {
entity_id: entity_id.clone(),
score: dist,
distance: dist,
rank: 0, });
}
}
candidates.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
let results: Vec<SearchResult> = candidates
.into_iter()
.take(k)
.enumerate()
.map(|(rank, mut r)| {
r.distance = r.score;
r.rank = rank + 1;
r
})
.collect();
Ok(results)
}
pub fn stats(&self) -> IvfPqStats {
let avg_list_size = if self.centroids.is_empty() {
0.0
} else {
self.size as f32 / self.centroids.len() as f32
};
let memory_bytes = self.estimate_memory();
IvfPqStats {
nclusters: self.centroids.len(),
nvectors: self.size,
dimension: self.dim.unwrap_or(0),
avg_list_size,
memory_bytes,
compression_ratio: self.compression_ratio(),
}
}
fn estimate_memory(&self) -> usize {
let centroids_mem = self.centroids.len() * self.dim.unwrap_or(0) * 4;
let inverted_mem = self.size * self.config.nsubvectors;
let pq_mem = if let Some(pq) = &self.pq {
pq.nsubvectors * pq.ncentroids * pq.subvector_dim * 4
} else {
0
};
centroids_mem + inverted_mem + pq_mem
}
fn compression_ratio(&self) -> f32 {
if self.size == 0 || self.dim.is_none() {
return 0.0;
}
let original_size = self.size * self.dim.unwrap() * 4; let compressed_size = self.estimate_memory();
original_size as f32 / compressed_size as f32
}
}
#[derive(Debug, Clone)]
pub struct IvfPqStats {
pub nclusters: usize,
pub nvectors: usize,
pub dimension: usize,
pub avg_list_size: f32,
pub memory_bytes: usize,
pub compression_ratio: f32,
}
fn kmeans(vectors: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
if vectors.is_empty() {
anyhow::bail!("Cannot run k-means on empty vector set");
}
let dim = vectors[0].len();
let n = vectors.len();
if k > n {
anyhow::bail!("Number of clusters {} exceeds number of vectors {}", k, n);
}
let mut rng = rand::rng();
let mut centroids = Vec::with_capacity(k);
let first_idx = rng.random_range(0..n);
centroids.push(vectors[first_idx].clone());
for _ in 1..k {
let distances: Vec<f32> = vectors
.iter()
.map(|v| {
centroids
.iter()
.map(|c| euclidean_distance(v, c))
.fold(f32::MAX, f32::min)
})
.collect();
let total: f32 = distances.iter().map(|d| d * d).sum();
let mut threshold = rng.random_range(0.0..total);
for (idx, &dist) in distances.iter().enumerate() {
threshold -= dist * dist;
if threshold <= 0.0 {
centroids.push(vectors[idx].clone());
break;
}
}
}
for _iter in 0..max_iterations {
let assignments: Vec<usize> = vectors
.par_iter()
.map(|v| {
centroids
.iter()
.enumerate()
.map(|(idx, c)| (idx, euclidean_distance(v, c)))
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.unwrap()
.0
})
.collect();
let mut new_centroids = vec![vec![0.0; dim]; k];
let mut counts = vec![0; k];
for (vec, &cluster_id) in vectors.iter().zip(&assignments) {
for (i, &val) in vec.iter().enumerate() {
new_centroids[cluster_id][i] += val;
}
counts[cluster_id] += 1;
}
for (centroid, count) in new_centroids.iter_mut().zip(&counts) {
if *count > 0 {
for val in centroid.iter_mut() {
*val /= *count as f32;
}
}
}
let mut total_movement = 0.0;
for (old, new) in centroids.iter().zip(&new_centroids) {
total_movement += euclidean_distance(old, new);
}
centroids = new_centroids;
if total_movement < 0.001 {
break;
}
}
Ok(centroids)
}
#[inline]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
simd::euclidean_distance_simd(a, b)
}
#[inline]
fn compute_distance(metric: &DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
simd::compute_distance_lower_is_better_simd(*metric, a, b)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ivf_pq_creation() {
let config = IvfPqConfig::default()
.with_nclusters(16)
.with_nsubvectors(8);
let index = IvfPqIndex::new(config);
assert_eq!(index.config.nclusters, 16);
assert_eq!(index.config.nsubvectors, 8);
}
#[test]
fn test_product_quantizer() {
let dim = 64;
let nsubvectors = 8;
let nbits = 8;
let pq = ProductQuantizer::new(dim, nsubvectors, nbits);
assert!(pq.is_ok());
let pq = pq.unwrap();
assert_eq!(pq.subvector_dim, 8);
assert_eq!(pq.ncentroids, 256);
}
#[test]
fn test_kmeans_basic() {
let vectors = vec![
vec![1.0, 0.0],
vec![1.1, 0.1],
vec![0.0, 1.0],
vec![0.1, 1.1],
];
let centroids = kmeans(&vectors, 2, 10);
assert!(centroids.is_ok());
let centroids = centroids.unwrap();
assert_eq!(centroids.len(), 2);
}
#[test]
fn test_ivf_pq_build_and_search() {
let mut vectors = HashMap::new();
for i in 0..300 {
let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
vectors.insert(format!("doc{}", i), vec);
}
let config = IvfPqConfig::default()
.with_nclusters(8)
.with_nsubvectors(8)
.with_nbits(4) .with_nprobe(2);
let mut index = IvfPqIndex::new(config);
let build_result = index.build(&vectors);
if let Err(e) = &build_result {
panic!("Build failed: {}", e);
}
let query = vectors.get("doc150").unwrap().clone();
let results = index.search(&query, 5);
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 5);
assert!(results[0].entity_id.starts_with("doc"));
}
#[test]
fn test_ivf_pq_nprobe_effect() {
let mut vectors = HashMap::new();
for i in 0..300 {
let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
vectors.insert(format!("doc{}", i), vec);
}
let config1 = IvfPqConfig::default()
.with_nclusters(4)
.with_nsubvectors(8)
.with_nbits(4) .with_nprobe(1);
let mut index1 = IvfPqIndex::new(config1);
assert!(index1.build(&vectors).is_ok());
let config2 = IvfPqConfig::default()
.with_nclusters(4)
.with_nsubvectors(8)
.with_nbits(4) .with_nprobe(4);
let mut index2 = IvfPqIndex::new(config2);
assert!(index2.build(&vectors).is_ok());
let query = vectors.get("doc150").unwrap().clone();
let results1 = index1.search(&query, 5).unwrap();
let results2 = index2.search(&query, 5).unwrap();
assert_eq!(results1.len(), 5);
assert_eq!(results2.len(), 5);
assert!(results1[0].score >= 0.0);
assert!(results2[0].score >= 0.0);
}
#[test]
fn test_ivf_pq_stats() {
let mut vectors = HashMap::new();
for i in 0..300 {
let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 300.0).collect();
vectors.insert(format!("doc{}", i), vec);
}
let config = IvfPqConfig::default()
.with_nclusters(10)
.with_nsubvectors(16)
.with_nbits(4);
let mut index = IvfPqIndex::new(config);
assert!(index.build(&vectors).is_ok());
let stats = index.stats();
assert_eq!(stats.nclusters, 10);
assert_eq!(stats.nvectors, 300);
assert_eq!(stats.dimension, 128);
assert!(stats.avg_list_size > 0.0);
assert!(stats.memory_bytes > 0);
assert!(stats.compression_ratio > 1.0); }
#[test]
fn test_ivf_pq_compression_ratio() {
let mut vectors = HashMap::new();
for i in 0..200 {
let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 200.0).collect();
vectors.insert(format!("doc{}", i), vec);
}
let config = IvfPqConfig {
nclusters: 8,
nsubvectors: 8, nbits: 4, max_kmeans_iterations: 20, ..IvfPqConfig::default()
};
let mut index = IvfPqIndex::new(config);
assert!(index.build(&vectors).is_ok());
let stats = index.stats();
let original_size = 200 * 128 * 4;
assert!(stats.memory_bytes < original_size);
assert!(stats.compression_ratio > 1.0);
println!(
"Compression: {:.2}x (original: {} bytes, compressed: {} bytes)",
stats.compression_ratio, original_size, stats.memory_bytes
);
}
#[test]
#[ignore]
fn test_ivf_pq_compression_ratio_full() {
let mut vectors = HashMap::new();
for i in 0..500 {
let vec: Vec<f32> = (0..768).map(|j| (i + j) as f32 / 500.0).collect();
vectors.insert(format!("doc{}", i), vec);
}
let config = IvfPqConfig::default()
.with_nclusters(16)
.with_nsubvectors(64)
.with_nbits(6);
let mut index = IvfPqIndex::new(config);
assert!(index.build(&vectors).is_ok());
let stats = index.stats();
let original_size = 500 * 768 * 4;
assert!(stats.memory_bytes < original_size);
assert!(stats.compression_ratio > 1.0);
println!(
"Compression: {:.2}x (original: {} bytes, compressed: {} bytes)",
stats.compression_ratio, original_size, stats.memory_bytes
);
}
#[test]
fn test_ivf_pq_empty_vectors_error() {
let vectors = HashMap::new();
let config = IvfPqConfig::default();
let mut index = IvfPqIndex::new(config);
let result = index.build(&vectors);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Cannot build index with empty vector collection"));
}
#[test]
fn test_ivf_pq_search_before_build_error() {
let config = IvfPqConfig::default();
let index = IvfPqIndex::new(config);
let query = vec![0.1; 64];
let result = index.search(&query, 10);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Index not built"));
}
#[test]
fn test_ivf_pq_invalid_dimension_error() {
let _config = IvfPqConfig::default().with_nsubvectors(8);
let pq = ProductQuantizer::new(65, 8, 8);
assert!(pq.is_err());
assert!(pq.unwrap_err().to_string().contains("must be divisible by"));
}
#[test]
fn test_ivf_pq_different_metrics() {
let mut vectors = HashMap::new();
for i in 0..300 {
let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
vectors.insert(format!("doc{}", i), vec);
}
let query = vectors.get("doc150").unwrap().clone();
let metrics = vec![
DistanceMetric::Cosine,
DistanceMetric::Euclidean,
DistanceMetric::DotProduct,
DistanceMetric::Manhattan,
];
for metric in metrics {
let config = IvfPqConfig::default()
.with_nclusters(4)
.with_nsubvectors(8)
.with_nbits(4) .with_metric(metric);
let mut index = IvfPqIndex::new(config);
assert!(index.build(&vectors).is_ok());
let results = index.search(&query, 3);
assert!(results.is_ok());
let results = results.unwrap();
assert_eq!(results.len(), 3);
}
}
#[test]
fn test_product_quantizer_encode_decode() {
let dim = 64;
let nsubvectors = 8;
let nbits = 4;
let mut pq = ProductQuantizer::new(dim, nsubvectors, nbits).unwrap();
let mut train_vectors = Vec::new();
for i in 0..100 {
let vec: Vec<f32> = (0..dim).map(|j| (i + j) as f32 / 100.0).collect();
train_vectors.push(vec);
}
let train_result = pq.train(&train_vectors, 20);
if let Err(e) = &train_result {
panic!("PQ training failed: {}", e);
}
let test_vector: Vec<f32> = (0..dim).map(|i| i as f32 / 64.0).collect();
let codes = pq.encode(&test_vector);
assert_eq!(codes.len(), nsubvectors);
for &code in &codes {
assert!((code as usize) < pq.ncentroids);
}
let distance = pq.asymmetric_distance(&test_vector, &codes);
assert!(distance >= 0.0);
}
#[test]
fn test_kmeans_convergence() {
let mut vectors = Vec::new();
for i in 0..20 {
vectors.push(vec![1.0 + (i as f32) * 0.01, 1.0 + (i as f32) * 0.01]);
}
for i in 0..20 {
vectors.push(vec![10.0 + (i as f32) * 0.01, 10.0 + (i as f32) * 0.01]);
}
let centroids = kmeans(&vectors, 2, 50).unwrap();
assert_eq!(centroids.len(), 2);
let mut has_low_centroid = false;
let mut has_high_centroid = false;
for centroid in ¢roids {
if centroid[0] < 5.0 {
has_low_centroid = true;
assert!(centroid[0] > 0.5 && centroid[0] < 1.5);
} else {
has_high_centroid = true;
assert!(centroid[0] > 9.5 && centroid[0] < 10.5);
}
}
assert!(has_low_centroid);
assert!(has_high_centroid);
}
#[test]
fn test_kmeans_error_cases() {
let empty_vectors: Vec<Vec<f32>> = vec![];
let result = kmeans(&empty_vectors, 2, 10);
assert!(result.is_err());
let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
let result = kmeans(&vectors, 5, 10);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("exceeds"));
}
}