use crate::{
pq::{PQConfig, PQIndex},
Vector, VectorIndex,
};
use anyhow::{anyhow, Result};
use std::sync::{Arc, RwLock};
/// Vector-compression strategy applied inside each inverted list.
#[derive(Debug, Clone, PartialEq)]
pub enum QuantizationStrategy {
    /// Store raw vectors; no compression.
    None,
    /// Single-stage product quantization of cluster residuals.
    ProductQuantization(PQConfig),
    /// Multi-stage PQ: each level quantizes the residual left over by the
    /// previous level's reconstruction.
    ResidualQuantization {
        /// Number of quantization stages to apply.
        levels: usize,
        /// One PQ configuration per stage; callers must supply at least
        /// `levels` entries (extra entries are ignored).
        pq_configs: Vec<PQConfig>,
    },
    /// Several independent PQ codebooks whose distances are blended by weight.
    MultiCodebook {
        /// Number of parallel codebooks.
        num_codebooks: usize,
        /// One PQ configuration per codebook.
        pq_configs: Vec<PQConfig>,
    },
}
/// Configuration for an IVF (inverted-file) vector index.
#[derive(Debug, Clone)]
pub struct IvfConfig {
    /// Number of k-means clusters (one inverted list per cluster).
    pub n_clusters: usize,
    /// How many nearest clusters to scan per query.
    pub n_probes: usize,
    /// Upper bound on k-means iterations during training.
    pub max_iterations: usize,
    /// k-means stops once the change in total centroid movement between
    /// iterations falls below this value.
    pub convergence_threshold: f32,
    /// Seed for deterministic centroid initialization (42 is used when `None`).
    pub seed: Option<u64>,
    /// Per-list storage/compression strategy.
    pub quantization: QuantizationStrategy,
    /// Legacy switch: force PQ residual quantization using `pq_config`,
    /// overriding `quantization` at index construction time.
    pub enable_residual_quantization: bool,
    /// PQ settings; required when `enable_residual_quantization` is set.
    pub pq_config: Option<PQConfig>,
}
impl Default for IvfConfig {
    /// Defaults: 256 clusters probed 8 at a time, up to 100 k-means
    /// iterations, and plain (unquantized) storage.
    fn default() -> Self {
        Self {
            n_clusters: 256,
            n_probes: 8,
            max_iterations: 100,
            convergence_threshold: 1e-4,
            seed: None,
            quantization: QuantizationStrategy::None,
            enable_residual_quantization: false,
            pq_config: None,
        }
    }
}
/// Physical representation of one stored vector inside an inverted list.
#[derive(Debug, Clone)]
enum VectorStorage {
    /// Uncompressed vector.
    Full(Vector),
    /// Single-stage PQ codes of the cluster residual.
    Quantized(Vec<u8>),
    /// Stage-wise PQ codes; `final_residual` keeps any leftover residual
    /// when fewer codebooks exist than requested levels.
    MultiLevelQuantized {
        levels: Vec<Vec<u8>>,
        final_residual: Option<Vector>,
    },
    /// Codes from several parallel codebooks plus their blend weights.
    MultiCodebook {
        codebooks: Vec<Vec<u8>>,
        weights: Vec<f32>,
    },
}
/// One posting list: the vectors assigned to a single centroid, stored
/// according to the list's quantization strategy.
#[derive(Debug, Clone)]
struct InvertedList {
    /// `(uri, stored representation)` pairs.
    vectors: Vec<(String, VectorStorage)>,
    /// Strategy governing how new residuals are encoded.
    quantization: QuantizationStrategy,
    /// Codebook used by `ProductQuantization` (otherwise `None`).
    pq_index: Option<PQIndex>,
    /// Stage codebooks for `ResidualQuantization`.
    multi_level_pq: Vec<PQIndex>,
    /// Parallel codebooks for `MultiCodebook`.
    multi_codebook_pq: Vec<PQIndex>,
    /// Blend weights, index-aligned with `multi_codebook_pq`.
    codebook_weights: Vec<f32>,
}
impl InvertedList {
fn new() -> Self {
Self {
vectors: Vec::new(),
quantization: QuantizationStrategy::None,
pq_index: None,
multi_level_pq: Vec::new(),
multi_codebook_pq: Vec::new(),
codebook_weights: Vec::new(),
}
}
/// Create an empty posting list wired up for the given strategy.
///
/// Pre-builds whichever PQ machinery the strategy calls for so that
/// inserts can encode immediately.
fn new_with_quantization(quantization: QuantizationStrategy) -> Result<Self> {
    let mut pq_index = None;
    let mut multi_level_pq = Vec::new();
    let mut multi_codebook_pq = Vec::new();
    let mut codebook_weights = Vec::new();
    match &quantization {
        QuantizationStrategy::None => {}
        QuantizationStrategy::ProductQuantization(pq_config) => {
            pq_index = Some(PQIndex::new(pq_config.clone()));
        }
        QuantizationStrategy::ResidualQuantization { pq_configs, .. } => {
            multi_level_pq = pq_configs.iter().cloned().map(PQIndex::new).collect();
        }
        QuantizationStrategy::MultiCodebook {
            num_codebooks,
            pq_configs,
        } => {
            multi_codebook_pq = pq_configs.iter().cloned().map(PQIndex::new).collect();
            // Start with uniform weights; training may rebalance them later.
            codebook_weights = vec![1.0 / *num_codebooks as f32; *num_codebooks];
        }
    }
    Ok(Self {
        vectors: Vec::new(),
        quantization,
        pq_index,
        multi_level_pq,
        multi_codebook_pq,
        codebook_weights,
    })
}
/// Convenience constructor: a list using single-stage product quantization.
fn new_with_pq(pq_config: PQConfig) -> Result<Self> {
    Self::new_with_quantization(QuantizationStrategy::ProductQuantization(pq_config))
}
/// Store `vector` verbatim (no compression) under `uri`.
fn add_full(&mut self, uri: String, vector: Vector) {
    self.vectors.push((uri, VectorStorage::Full(vector)));
}
/// Store the cluster residual (`vector - centroid`) for `uri`, encoded
/// according to this list's strategy.
///
/// PQ-backed strategies bootstrap-train an untrained codebook on the single
/// residual at hand so insertion never fails outright; accuracy improves
/// once `train_pq` has run on a real sample.
fn add_residual(&mut self, uri: String, residual: Vector, centroid: &Vector) -> Result<()> {
    match &self.quantization {
        QuantizationStrategy::ProductQuantization(_) => {
            if let Some(ref mut pq_index) = self.pq_index {
                if !pq_index.is_trained() {
                    // Bootstrap: train on the lone residual we have.
                    let training_residuals = vec![residual.clone()];
                    pq_index.train(&training_residuals)?;
                }
                let codes = pq_index.encode(&residual)?;
                self.vectors.push((uri, VectorStorage::Quantized(codes)));
            } else {
                return Err(anyhow!(
                    "PQ index not initialized for residual quantization"
                ));
            }
        }
        QuantizationStrategy::ResidualQuantization { levels, .. } => {
            self.add_multi_level_residual(uri, residual, *levels)?;
        }
        QuantizationStrategy::MultiCodebook { .. } => {
            self.add_multi_codebook(uri, residual)?;
        }
        QuantizationStrategy::None => {
            // Bug fix: `search` compares `Full` entries against the raw query,
            // not the query residual, so storing the bare residual here would
            // yield meaningless distances. Reconstruct the original vector
            // from the centroid (previously an unused parameter) instead.
            self.add_full(uri, residual.add(centroid)?);
        }
    }
    Ok(())
}
/// Encode a residual stage by stage, peeling each level's reconstruction
/// off before handing the remainder to the next level.
fn add_multi_level_residual(
    &mut self,
    uri: String,
    mut residual: Vector,
    levels: usize,
) -> Result<()> {
    let usable_levels = levels.min(self.multi_level_pq.len());
    let mut level_codes = Vec::with_capacity(usable_levels);
    for pq in self.multi_level_pq.iter_mut().take(usable_levels) {
        if !pq.is_trained() {
            // Bootstrap training on the single residual at hand.
            pq.train(&[residual.clone()])?;
        }
        let codes = pq.encode(&residual)?;
        let approximation = pq.decode_vector(&codes)?;
        residual = residual.subtract(&approximation)?;
        level_codes.push(codes);
    }
    // When fewer codebooks exist than requested levels, keep the leftover
    // residual verbatim so distance computation stays exact.
    let final_residual = if level_codes.len() < levels {
        Some(residual)
    } else {
        None
    };
    self.vectors.push((
        uri,
        VectorStorage::MultiLevelQuantized {
            levels: level_codes,
            final_residual,
        },
    ));
    Ok(())
}
/// Encode a residual independently with every codebook and store all code
/// sets together with the current blend weights.
fn add_multi_codebook(&mut self, uri: String, residual: Vector) -> Result<()> {
    let mut codebook_codes = Vec::with_capacity(self.multi_codebook_pq.len());
    for pq in &mut self.multi_codebook_pq {
        if !pq.is_trained() {
            // Bootstrap training on the single residual at hand.
            pq.train(&[residual.clone()])?;
        }
        codebook_codes.push(pq.encode(&residual)?);
    }
    self.vectors.push((
        uri,
        VectorStorage::MultiCodebook {
            codebooks: codebook_codes,
            // Snapshot the weights in effect at insert time.
            weights: self.codebook_weights.clone(),
        },
    ));
    Ok(())
}
/// Rank this list's entries against `query` and return the best `k` as
/// `(uri, similarity)` pairs, where similarity = 1 / (1 + distance).
fn search(&self, query: &Vector, centroid: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
    let query_residual = query.subtract(centroid)?;
    // Score every entry, dispatching on its storage representation.
    let mut scored = Vec::with_capacity(self.vectors.len());
    for (uri, storage) in &self.vectors {
        let distance = match storage {
            // Full vectors compare against the raw query.
            VectorStorage::Full(vec) => query.euclidean_distance(vec).unwrap_or(f32::INFINITY),
            // Quantized forms compare against the query residual.
            VectorStorage::Quantized(codes) => match &self.pq_index {
                Some(pq) => pq.compute_distance(&query_residual, codes)?,
                None => f32::INFINITY,
            },
            VectorStorage::MultiLevelQuantized {
                levels,
                final_residual,
            } => self.compute_multi_level_distance(&query_residual, levels, final_residual)?,
            VectorStorage::MultiCodebook { codebooks, weights } => {
                self.compute_multi_codebook_distance(&query_residual, codebooks, weights)?
            }
        };
        scored.push((uri.clone(), distance));
    }
    scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(k);
    // Map distances into (0, 1]: smaller distance => higher similarity.
    Ok(scored
        .into_iter()
        .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
        .collect())
}
/// Distance between the query residual and a multi-level-quantized entry,
/// obtained by summing every stage's decoded contribution (plus any stored
/// leftover residual) and measuring against the query residual.
fn compute_multi_level_distance(
    &self,
    query_residual: &Vector,
    level_codes: &[Vec<u8>],
    final_residual: &Option<Vector>,
) -> Result<f32> {
    let mut reconstructed = Vector::new(vec![0.0; query_residual.dimensions]);
    // `zip` stops at the shorter side, mirroring the old level-bounds guard.
    for (codes, pq) in level_codes.iter().zip(&self.multi_level_pq) {
        reconstructed = reconstructed.add(&pq.decode_vector(codes)?)?;
    }
    if let Some(final_res) = final_residual {
        reconstructed = reconstructed.add(final_res)?;
    }
    query_residual.euclidean_distance(&reconstructed)
}
/// Weighted average of per-codebook distances; `INFINITY` when no codebook
/// contributed any weight.
fn compute_multi_codebook_distance(
    &self,
    query_residual: &Vector,
    codebook_codes: &[Vec<u8>],
    weights: &[f32],
) -> Result<f32> {
    let mut weighted_sum = 0.0;
    let mut weight_total = 0.0;
    // Pair each code set with its codebook and weight; the double zip
    // bounds all three sequences, replacing explicit index guards.
    for ((codes, pq), &w) in codebook_codes
        .iter()
        .zip(&self.multi_codebook_pq)
        .zip(weights)
    {
        weighted_sum += w * pq.compute_distance(query_residual, codes)?;
        weight_total += w;
    }
    if weight_total > 0.0 {
        Ok(weighted_sum / weight_total)
    } else {
        Ok(f32::INFINITY)
    }
}
/// Train this list's quantizer(s) on a sample of cluster residuals,
/// dispatching on the configured strategy.
fn train_pq(&mut self, residuals: &[Vector]) -> Result<()> {
    match &self.quantization {
        QuantizationStrategy::None => {}
        QuantizationStrategy::ProductQuantization(_) => {
            if let Some(pq) = self.pq_index.as_mut() {
                pq.train(residuals)?;
            }
        }
        QuantizationStrategy::ResidualQuantization { levels, .. } => {
            // Copy the count so `self` is free for the mutable call below.
            let levels = *levels;
            self.train_multi_level_pq(residuals, levels)?;
        }
        QuantizationStrategy::MultiCodebook { .. } => {
            self.train_multi_codebook_pq(residuals)?;
        }
    }
    Ok(())
}
/// Classic residual-quantizer training: each stage learns on whatever the
/// previous stages failed to explain.
fn train_multi_level_pq(&mut self, residuals: &[Vector], levels: usize) -> Result<()> {
    let stages = levels.min(self.multi_level_pq.len());
    let mut current: Vec<Vector> = residuals.to_vec();
    for level in 0..stages {
        self.multi_level_pq[level].train(&current)?;
        // Replace the sample with the per-vector leftovers for the next stage.
        let pq = &self.multi_level_pq[level];
        current = current
            .iter()
            .map(|residual| {
                let codes = pq.encode(residual)?;
                let approximation = pq.decode_vector(&codes)?;
                residual.subtract(&approximation)
            })
            .collect::<Result<Vec<_>>>()?;
    }
    Ok(())
}
/// Train every parallel codebook on the same residual sample, then
/// rebalance the blend weights by reconstruction quality.
fn train_multi_codebook_pq(&mut self, residuals: &[Vector]) -> Result<()> {
    for pq in self.multi_codebook_pq.iter_mut() {
        pq.train(residuals)?;
    }
    self.optimize_codebook_weights(residuals)
}
/// Reweight codebooks so that lower mean reconstruction error earns a
/// larger (normalized) blend weight.
fn optimize_codebook_weights(&mut self, residuals: &[Vector]) -> Result<()> {
    if self.multi_codebook_pq.is_empty() || residuals.is_empty() {
        return Ok(());
    }
    // Mean round-trip (encode + decode) error per codebook.
    let mut mean_errors = Vec::with_capacity(self.multi_codebook_pq.len());
    for pq in &self.multi_codebook_pq {
        let mut total_error = 0.0;
        for residual in residuals {
            let codes = pq.encode(residual)?;
            let reconstruction = pq.decode_vector(&codes)?;
            total_error += residual
                .euclidean_distance(&reconstruction)
                .unwrap_or(f32::INFINITY);
        }
        mean_errors.push(total_error / residuals.len() as f32);
    }
    let max_error = mean_errors.iter().fold(0.0f32, |acc, &e| acc.max(e));
    if max_error > 0.0 {
        // Lower error => larger weight; the epsilon keeps even the worst
        // codebook's weight strictly positive.
        for (weight, &err) in self.codebook_weights.iter_mut().zip(&mean_errors) {
            *weight = (max_error - err + 1e-6) / max_error;
        }
        let total: f32 = self.codebook_weights.iter().sum();
        if total > 0.0 {
            for weight in &mut self.codebook_weights {
                *weight /= total;
            }
        }
    }
    Ok(())
}
/// Tally storage kinds in one pass and summarize this posting list.
fn stats(&self) -> InvertedListStats {
    let (mut full, mut single_stage, mut multi_level, mut multi_codebook) = (0, 0, 0, 0);
    for (_, storage) in &self.vectors {
        match storage {
            VectorStorage::Full(_) => full += 1,
            VectorStorage::Quantized(_) => single_stage += 1,
            VectorStorage::MultiLevelQuantized { .. } => multi_level += 1,
            VectorStorage::MultiCodebook { .. } => multi_codebook += 1,
        }
    }
    // Every non-Full representation counts as quantized.
    let quantized = single_stage + multi_level + multi_codebook;
    let total = self.vectors.len();
    InvertedListStats {
        total_vectors: total,
        full_vectors: full,
        quantized_vectors: quantized,
        // Fraction of entries stored in compressed form.
        compression_ratio: if total > 0 {
            quantized as f32 / total as f32
        } else {
            0.0
        },
        multi_level_vectors: multi_level,
        multi_codebook_vectors: multi_codebook,
        quantization_strategy: self.quantization.clone(),
    }
}
}
/// Storage summary for one inverted list (or an aggregate over all lists).
#[derive(Debug, Clone)]
pub struct InvertedListStats {
    /// Total entries in the list.
    pub total_vectors: usize,
    /// Entries stored uncompressed.
    pub full_vectors: usize,
    /// Entries stored in any quantized form.
    pub quantized_vectors: usize,
    /// `quantized_vectors / total_vectors` (0.0 when empty).
    pub compression_ratio: f32,
    /// Entries using multi-level (residual) quantization.
    pub multi_level_vectors: usize,
    /// Entries using multi-codebook quantization.
    pub multi_codebook_vectors: usize,
    /// Strategy in effect for the list.
    pub quantization_strategy: QuantizationStrategy,
}
/// IVF (inverted-file) approximate nearest-neighbor index: k-means
/// partitions the space, and queries probe only the closest clusters.
pub struct IvfIndex {
    /// Build-time configuration.
    config: IvfConfig,
    /// One centroid per cluster; populated by `train`.
    centroids: Vec<Vector>,
    /// One posting list per centroid; `RwLock` allows concurrent reads
    /// during search.
    inverted_lists: Vec<Arc<RwLock<InvertedList>>>,
    /// Dimensionality learned from the training set.
    dimensions: Option<usize>,
    /// Count of inserted vectors.
    n_vectors: usize,
    /// Set once `train` succeeds; inserts/searches are rejected before that.
    is_trained: bool,
}
impl IvfIndex {
/// Build an untrained index from `config`, pre-allocating one posting
/// list per cluster.
///
/// # Errors
/// Fails when `enable_residual_quantization` is set without a `pq_config`.
pub fn new(config: IvfConfig) -> Result<Self> {
    // The legacy flag overrides the strategy field and demands an explicit
    // PQ configuration.
    let quantization = if config.enable_residual_quantization {
        match &config.pq_config {
            Some(pq_config) => QuantizationStrategy::ProductQuantization(pq_config.clone()),
            None => {
                return Err(anyhow!(
                    "PQ config required when residual quantization is enabled"
                ))
            }
        }
    } else {
        config.quantization.clone()
    };
    let inverted_lists = (0..config.n_clusters)
        .map(|_| {
            InvertedList::new_with_quantization(quantization.clone())
                .map(|list| Arc::new(RwLock::new(list)))
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(Self {
        config,
        centroids: Vec::new(),
        inverted_lists,
        dimensions: None,
        n_vectors: 0,
        is_trained: false,
    })
}
/// Build an index that product-quantizes cluster residuals with `pq_config`.
pub fn new_with_product_quantization(
    n_clusters: usize,
    n_probes: usize,
    pq_config: PQConfig,
) -> Result<Self> {
    let config = IvfConfig {
        n_clusters,
        n_probes,
        quantization: QuantizationStrategy::ProductQuantization(pq_config),
        ..Default::default()
    };
    Self::new(config)
}
/// Build an index with multi-level residual quantization.
///
/// # Errors
/// Fails when fewer PQ configs than `levels` are supplied.
pub fn new_with_multi_level_quantization(
    n_clusters: usize,
    n_probes: usize,
    levels: usize,
    pq_configs: Vec<PQConfig>,
) -> Result<Self> {
    if pq_configs.len() < levels {
        return Err(anyhow!(
            "Number of PQ configs must be at least equal to levels"
        ));
    }
    let config = IvfConfig {
        n_clusters,
        n_probes,
        quantization: QuantizationStrategy::ResidualQuantization { levels, pq_configs },
        ..Default::default()
    };
    Self::new(config)
}
/// Build an index with multiple parallel PQ codebooks per list.
///
/// # Errors
/// Fails when `pq_configs.len() != num_codebooks`.
pub fn new_with_multi_codebook_quantization(
    n_clusters: usize,
    n_probes: usize,
    num_codebooks: usize,
    pq_configs: Vec<PQConfig>,
) -> Result<Self> {
    if pq_configs.len() != num_codebooks {
        return Err(anyhow!(
            "Number of PQ configs must equal number of codebooks"
        ));
    }
    let config = IvfConfig {
        n_clusters,
        n_probes,
        quantization: QuantizationStrategy::MultiCodebook {
            num_codebooks,
            pq_configs,
        },
        ..Default::default()
    };
    Self::new(config)
}
/// Backward-compatible alias for [`Self::new_with_product_quantization`].
pub fn new_with_residual_quantization(
    n_clusters: usize,
    n_probes: usize,
    pq_config: PQConfig,
) -> Result<Self> {
    Self::new_with_product_quantization(n_clusters, n_probes, pq_config)
}
/// The configuration this index was built with.
pub fn config(&self) -> &IvfConfig {
    &self.config
}
/// Learn cluster centroids from `training_vectors` via k-means (seeded by
/// k-means++), then train any configured quantizers on the residuals.
///
/// # Errors
/// Fails on an empty training set or mixed dimensionalities.
pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
    if training_vectors.is_empty() {
        return Err(anyhow!("Cannot train IVF index with empty training set"));
    }
    let dims = training_vectors[0].dimensions;
    if training_vectors.iter().any(|v| v.dimensions != dims) {
        return Err(anyhow!(
            "All training vectors must have the same dimensions"
        ));
    }
    self.dimensions = Some(dims);
    self.centroids = self.initialize_centroids_kmeans_plus_plus(training_vectors)?;
    // Lloyd's iterations, capped at `max_iterations`.
    let mut prev_shift = f32::INFINITY;
    for _ in 0..self.config.max_iterations {
        // Assignment step: bucket every vector under its nearest centroid.
        let mut clusters: Vec<Vec<&Vector>> = vec![Vec::new(); self.config.n_clusters];
        for vector in training_vectors {
            clusters[self.find_nearest_centroid(vector)?].push(vector);
        }
        // Update step: move each non-empty centroid, tracking total movement.
        let mut shift = 0.0;
        for (i, cluster) in clusters.iter().enumerate() {
            if cluster.is_empty() {
                continue; // empty clusters keep their previous centroid
            }
            let new_centroid = self.compute_centroid(cluster);
            shift += self.centroids[i]
                .euclidean_distance(&new_centroid)
                .unwrap_or(0.0);
            self.centroids[i] = new_centroid;
        }
        // Converged once movement stabilizes between iterations.
        if (prev_shift - shift).abs() < self.config.convergence_threshold {
            break;
        }
        prev_shift = shift;
    }
    self.is_trained = true;
    // Train quantizers on per-cluster residuals when any quantization is on.
    let quantization_active = self.config.enable_residual_quantization
        || !matches!(self.config.quantization, QuantizationStrategy::None);
    if quantization_active {
        self.train_residual_quantization(training_vectors)?;
    }
    Ok(())
}
/// Bucket residuals (`vector - assigned centroid`) by cluster and train
/// each sufficiently-populated posting list's quantizer on its bucket.
fn train_residual_quantization(&mut self, training_vectors: &[Vector]) -> Result<()> {
    let mut cluster_residuals: Vec<Vec<Vector>> = vec![Vec::new(); self.config.n_clusters];
    for vector in training_vectors {
        let cluster_idx = self.find_nearest_centroid(vector)?;
        let residual = vector.subtract(&self.centroids[cluster_idx])?;
        cluster_residuals[cluster_idx].push(residual);
    }
    // Only train lists with a meaningful sample (> 10 residuals); tiny
    // clusters fall back to bootstrap training at insert time.
    for (list, residuals) in self.inverted_lists.iter().zip(&cluster_residuals) {
        if residuals.len() > 10 {
            list.write()
                .expect("inverted_lists lock should not be poisoned")
                .train_pq(residuals)?;
        }
    }
    Ok(())
}
/// Seed `n_clusters` centroids with a deterministic k-means++ scheme
/// (D^2 sampling) driven by a hash-seeded LCG so results are reproducible
/// for a given `config.seed`.
fn initialize_centroids_kmeans_plus_plus(&self, vectors: &[Vector]) -> Result<Vec<Vector>> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut hasher = DefaultHasher::new();
    self.config.seed.unwrap_or(42).hash(&mut hasher);
    let mut rng_state = hasher.finish();
    let mut centroids = Vec::with_capacity(self.config.n_clusters);
    let first_idx = (rng_state as usize) % vectors.len();
    centroids.push(vectors[first_idx].clone());
    while centroids.len() < self.config.n_clusters {
        // D^2 weighting: each point's squared distance to its nearest centroid.
        let mut distances = Vec::with_capacity(vectors.len());
        let mut sum_distances = 0.0;
        for vector in vectors {
            let min_dist = centroids
                .iter()
                .map(|c| vector.euclidean_distance(c).unwrap_or(f32::INFINITY))
                .fold(f32::INFINITY, |a, b| a.min(b));
            let weight = min_dist * min_dist;
            distances.push(weight);
            sum_distances += weight;
        }
        // Advance the LCG and draw a threshold in [0, sum_distances].
        rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
        let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;
        let mut cumulative = 0.0;
        let mut selected = None;
        for (i, &dist) in distances.iter().enumerate() {
            cumulative += dist;
            if cumulative >= threshold {
                selected = Some(i);
                break;
            }
        }
        // Bug fix: if `threshold` is NaN (e.g. `0.0 * INFINITY` when some
        // distance overflowed to infinity) the comparison above never
        // succeeds, nothing is pushed, and the outer `while` would spin
        // forever. Fall back to a deterministic pick so seeding always
        // terminates.
        let idx = selected.unwrap_or_else(|| (rng_state as usize) % vectors.len());
        centroids.push(vectors[idx].clone());
    }
    Ok(centroids)
}
/// Component-wise mean of a cluster; an empty cluster degenerates to a
/// zero vector of the index dimensionality.
fn compute_centroid(&self, cluster: &[&Vector]) -> Vector {
    if cluster.is_empty() {
        return Vector::new(vec![0.0; self.dimensions.unwrap_or(0)]);
    }
    let dims = cluster[0].dimensions;
    let mut mean = vec![0.0f32; dims];
    for vector in cluster {
        // Accumulate each vector into the running sums.
        for (acc, &val) in mean.iter_mut().zip(vector.as_f32().iter()) {
            *acc += val;
        }
    }
    let count = cluster.len() as f32;
    for component in &mut mean {
        *component /= count;
    }
    Vector::new(mean)
}
/// Index of the centroid closest to `vector` (Euclidean).
///
/// # Errors
/// Fails when the index has no centroids yet, or a distance computation fails.
fn find_nearest_centroid(&self, vector: &Vector) -> Result<usize> {
    if self.centroids.is_empty() {
        return Err(anyhow!("No centroids available"));
    }
    // Track (best index, best distance) as a single tuple.
    let mut best = (0usize, f32::INFINITY);
    for (i, centroid) in self.centroids.iter().enumerate() {
        let distance = vector.euclidean_distance(centroid)?;
        if distance < best.1 {
            best = (i, distance);
        }
    }
    Ok(best.0)
}
/// Indices of the `n_probes` centroids closest to `query`, nearest first.
/// Unreachable centroids (failed distance) rank last via `INFINITY`.
fn find_nearest_centroids(&self, query: &Vector, n_probes: usize) -> Result<Vec<usize>> {
    let mut ranked: Vec<(usize, f32)> = Vec::with_capacity(self.centroids.len());
    for (i, centroid) in self.centroids.iter().enumerate() {
        let dist = query.euclidean_distance(centroid).unwrap_or(f32::INFINITY);
        ranked.push((i, dist));
    }
    ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked.truncate(n_probes.min(self.centroids.len()));
    Ok(ranked.into_iter().map(|(i, _)| i).collect())
}
/// Snapshot of index-wide statistics, including an aggregate compression
/// summary across all posting lists.
pub fn stats(&self) -> IvfStats {
    let mut total = InvertedListStats {
        total_vectors: 0,
        full_vectors: 0,
        quantized_vectors: 0,
        compression_ratio: 0.0,
        multi_level_vectors: 0,
        multi_codebook_vectors: 0,
        // Bug fix: the aggregate previously always reported
        // `QuantizationStrategy::None` regardless of the configured
        // strategy; report the index's actual strategy instead.
        quantization_strategy: self.config.quantization.clone(),
    };
    let mut cluster_stats = Vec::with_capacity(self.inverted_lists.len());
    let mut vectors_per_cluster = Vec::with_capacity(self.inverted_lists.len());
    let mut non_empty_clusters = 0;
    for list in &self.inverted_lists {
        let stats = list
            .read()
            .expect("inverted list lock should not be poisoned")
            .stats();
        total.total_vectors += stats.total_vectors;
        total.full_vectors += stats.full_vectors;
        total.quantized_vectors += stats.quantized_vectors;
        total.multi_level_vectors += stats.multi_level_vectors;
        total.multi_codebook_vectors += stats.multi_codebook_vectors;
        vectors_per_cluster.push(stats.total_vectors);
        if stats.total_vectors > 0 {
            non_empty_clusters += 1;
        }
        cluster_stats.push(stats);
    }
    if total.total_vectors > 0 {
        total.compression_ratio =
            total.quantized_vectors as f32 / total.total_vectors as f32;
    }
    let avg_vectors_per_cluster = if self.config.n_clusters > 0 {
        self.n_vectors as f32 / self.config.n_clusters as f32
    } else {
        0.0
    };
    IvfStats {
        n_clusters: self.config.n_clusters,
        n_probes: self.config.n_probes,
        n_vectors: self.n_vectors,
        is_trained: self.is_trained,
        dimensions: self.dimensions,
        vectors_per_cluster,
        avg_vectors_per_cluster,
        non_empty_clusters,
        enable_residual_quantization: self.config.enable_residual_quantization,
        quantization_strategy: self.config.quantization.clone(),
        compression_stats: Some(total),
        cluster_stats,
    }
}
}
impl VectorIndex for IvfIndex {
/// Insert `vector` under `uri` into its nearest cluster's posting list.
///
/// # Errors
/// Fails before training, or on a dimensionality mismatch.
fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
    if !self.is_trained {
        return Err(anyhow!(
            "IVF index must be trained before inserting vectors"
        ));
    }
    if let Some(dims) = self.dimensions {
        if vector.dimensions != dims {
            return Err(anyhow!(
                "Vector dimensions {} don't match index dimensions {}",
                vector.dimensions,
                dims
            ));
        }
    }
    let cluster_idx = self.find_nearest_centroid(&vector)?;
    let centroid = &self.centroids[cluster_idx];
    // Residual encoding applies whenever any quantization strategy is
    // configured, or when the legacy flag forces it.
    let store_residual = self.config.enable_residual_quantization
        || !matches!(self.config.quantization, QuantizationStrategy::None);
    let mut list = self.inverted_lists[cluster_idx]
        .write()
        .expect("inverted_lists lock should not be poisoned");
    if store_residual {
        let residual = vector.subtract(centroid)?;
        list.add_residual(uri, residual, centroid)?;
    } else {
        list.add_full(uri, vector);
    }
    drop(list);
    self.n_vectors += 1;
    Ok(())
}
/// Approximate k-NN: probe the `n_probes` nearest clusters, merge their
/// candidates, and keep the top `k` by similarity.
fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
    if !self.is_trained {
        return Err(anyhow!("IVF index must be trained before searching"));
    }
    let mut candidates = Vec::new();
    for idx in self.find_nearest_centroids(query, self.config.n_probes)? {
        let guard = self.inverted_lists[idx]
            .read()
            .expect("inverted_lists lock should not be poisoned");
        candidates.extend(guard.search(query, &self.centroids[idx], k)?);
    }
    // Similarity scores: higher is better, so sort descending.
    candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    candidates.truncate(k);
    Ok(candidates)
}
fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
if !self.is_trained {
return Err(anyhow!("IVF index must be trained before searching"));
}
let probe_indices = self.find_nearest_centroids(query, self.config.n_probes)?;
let mut all_results = Vec::new();
for idx in probe_indices {
let list = self.inverted_lists[idx]
.read()
.expect("inverted_lists lock should not be poisoned");
let centroid = &self.centroids[idx];
let results = list.search(query, centroid, self.n_vectors)?;
for (uri, similarity) in results {
if similarity >= threshold {
all_results.push((uri, similarity));
}
}
}
all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(all_results)
}
/// Single-vector lookup is unsupported: entries are stored as residuals or
/// quantized codes, not retrievable originals, so this always returns `None`.
fn get_vector(&self, _uri: &str) -> Option<&Vector> {
    None
}
}
/// Index-wide statistics reported by [`IvfIndex::stats`].
#[derive(Debug, Clone)]
pub struct IvfStats {
    /// Total inserted vectors.
    pub n_vectors: usize,
    /// Configured cluster count.
    pub n_clusters: usize,
    /// Configured probe count per query.
    pub n_probes: usize,
    /// Whether `train` has completed.
    pub is_trained: bool,
    /// Dimensionality learned from training (None before training).
    pub dimensions: Option<usize>,
    /// Entry count of each posting list, in cluster order.
    pub vectors_per_cluster: Vec<usize>,
    /// `n_vectors / n_clusters` (0.0 when `n_clusters` is 0).
    pub avg_vectors_per_cluster: f32,
    /// Number of clusters holding at least one vector.
    pub non_empty_clusters: usize,
    /// Legacy residual-quantization flag from the config.
    pub enable_residual_quantization: bool,
    /// Configured quantization strategy.
    pub quantization_strategy: QuantizationStrategy,
    /// Aggregate storage summary across all posting lists.
    pub compression_stats: Option<InvertedListStats>,
    /// Per-cluster storage summaries, in cluster order.
    pub cluster_stats: Vec<InvertedListStats>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: train, insert, and query an unquantized index.
    #[test]
    fn test_ivf_basic() -> Result<()> {
        let config = IvfConfig {
            n_clusters: 4,
            n_probes: 2,
            ..Default::default()
        };
        let mut index = IvfIndex::new(config)?;
        // Eight points spread symmetrically around the origin in 2-D.
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, 0.0]),
            Vector::new(vec![0.0, -1.0]),
            Vector::new(vec![0.5, 0.5]),
            Vector::new(vec![-0.5, 0.5]),
            Vector::new(vec![-0.5, -0.5]),
            Vector::new(vec![0.5, -0.5]),
        ];
        index.train(&training_vectors)?;
        assert!(index.is_trained);
        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone())?;
        }
        // Query near vec0 = (1, 0); it must rank first.
        let query = Vector::new(vec![0.9, 0.1]);
        let results = index.search_knn(&query, 3)?;
        assert!(!results.is_empty());
        assert!(results.len() <= 3);
        assert_eq!(results[0].0, "vec0");
        Ok(())
    }

    /// Threshold search must only return results at or above the cutoff.
    #[test]
    fn test_ivf_threshold_search() -> Result<()> {
        let config = IvfConfig {
            n_clusters: 2,
            n_probes: 2,
            ..Default::default()
        };
        let mut index = IvfIndex::new(config)?;
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.0]),
        ];
        index.train(&training_vectors)?;
        index.insert("v1".to_string(), training_vectors[0].clone())?;
        index.insert("v2".to_string(), training_vectors[1].clone())?;
        index.insert("v3".to_string(), training_vectors[2].clone())?;
        index.insert("v4".to_string(), training_vectors[3].clone())?;
        let query = Vector::new(vec![0.9, 0.1, 0.0]);
        let results = index.search_threshold(&query, 0.5)?;
        assert!(!results.is_empty());
        // Every returned similarity must satisfy the threshold.
        for (_, similarity) in &results {
            assert!(*similarity >= 0.5);
        }
        Ok(())
    }

    /// Stats must reflect training state, insert count, and dimensionality.
    #[test]
    fn test_ivf_stats() -> Result<()> {
        let config = IvfConfig {
            n_clusters: 3,
            n_probes: 1,
            ..Default::default()
        };
        let mut index = IvfIndex::new(config)?;
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, -1.0]),
        ];
        index.train(&training_vectors)?;
        index.insert("a".to_string(), Vector::new(vec![1.1, 0.1]))?;
        index.insert("b".to_string(), Vector::new(vec![0.1, 1.1]))?;
        let stats = index.stats();
        assert_eq!(stats.n_vectors, 2);
        assert_eq!(stats.n_clusters, 3);
        assert!(stats.is_trained);
        assert_eq!(stats.dimensions, Some(2));
        Ok(())
    }

    /// Two-level residual quantization round-trip: entries must be stored
    /// in multi-level form and remain searchable.
    #[test]
    fn test_ivf_multi_level_quantization() -> Result<()> {
        use crate::pq::PQConfig;
        let pq_config_1 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let pq_config_2 = PQConfig {
            n_subquantizers: 2,
            n_bits: 4,
            ..Default::default()
        };
        let mut index =
            IvfIndex::new_with_multi_level_quantization(4, 2, 2, vec![pq_config_1, pq_config_2])?;
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.5, 0.5]),
        ];
        index.train(&training_vectors)?;
        assert!(index.is_trained);
        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone())?;
        }
        let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
        let results = index.search_knn(&query, 3)?;
        assert!(!results.is_empty());
        assert!(results.len() <= 3);
        let stats = index.stats();
        assert!(matches!(
            stats.quantization_strategy,
            QuantizationStrategy::ResidualQuantization { .. }
        ));
        if let Some(compression_stats) = &stats.compression_stats {
            assert!(compression_stats.multi_level_vectors > 0);
        }
        Ok(())
    }

    /// Two-codebook quantization round-trip: entries must be stored in
    /// multi-codebook form and remain searchable.
    #[test]
    fn test_ivf_multi_codebook_quantization() -> Result<()> {
        use crate::pq::PQConfig;
        let pq_config_1 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let pq_config_2 = PQConfig {
            n_subquantizers: 2,
            n_bits: 8,
            ..Default::default()
        };
        let mut index = IvfIndex::new_with_multi_codebook_quantization(
            4,
            2,
            2,
            vec![pq_config_1, pq_config_2],
        )?;
        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 1.0, 0.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 1.0, 0.0]),
            Vector::new(vec![0.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
        ];
        index.train(&training_vectors)?;
        assert!(index.is_trained);
        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone())?;
        }
        let query = Vector::new(vec![0.9, 0.1, 0.0, 0.0]);
        let results = index.search_knn(&query, 2)?;
        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        let stats = index.stats();
        assert!(matches!(
            stats.quantization_strategy,
            QuantizationStrategy::MultiCodebook { .. }
        ));
        if let Some(compression_stats) = &stats.compression_stats {
            assert!(compression_stats.multi_codebook_vectors > 0);
        }
        Ok(())
    }

    /// Index construction must succeed for every quantization strategy.
    #[test]
    fn test_quantization_strategies() {
        use crate::pq::PQConfig;
        let pq_config = PQConfig::default();
        let strategies = vec![
            QuantizationStrategy::None,
            QuantizationStrategy::ProductQuantization(pq_config.clone()),
            QuantizationStrategy::ResidualQuantization {
                levels: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
            QuantizationStrategy::MultiCodebook {
                num_codebooks: 2,
                pq_configs: vec![pq_config.clone(), pq_config.clone()],
            },
        ];
        for strategy in strategies {
            let config = IvfConfig {
                n_clusters: 2,
                n_probes: 1,
                quantization: strategy.clone(),
                ..Default::default()
            };
            let index = IvfIndex::new(config);
            assert!(
                index.is_ok(),
                "Failed to create index with strategy: {strategy:?}"
            );
        }
    }
}