use std::borrow::Cow;
use std::collections::HashMap;
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use crate::types::NodeId;
use crate::vector::distance::normalize;
use crate::vector::ivf::{kmeans_parallel, KMeansConfig};
use crate::vector::types::{
DistanceMetric, IvfConfig, MultiQueryAggregation, PqConfig, VectorManifest, VectorSearchResult,
};
/// Combined configuration for an IVF-PQ index.
///
/// Bundles the coarse (IVF) settings, the product-quantization (PQ) settings,
/// and the switch that selects residual encoding.
#[derive(Debug, Clone)]
pub struct IvfPqConfig {
/// Coarse quantizer settings (`n_clusters`, `n_probe`, `metric`).
pub ivf: IvfConfig,
/// Product quantizer settings (`num_subspaces`, `num_centroids`, `max_iterations`).
pub pq: PqConfig,
/// When `true`, PQ is trained on and encodes `vector - assigned IVF centroid`
/// instead of the raw vector.
pub use_residuals: bool,
}
impl Default for IvfPqConfig {
fn default() -> Self {
Self {
ivf: IvfConfig::default(),
pq: PqConfig::default(),
use_residuals: true,
}
}
}
impl IvfPqConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_n_clusters(mut self, n_clusters: usize) -> Self {
self.ivf.n_clusters = n_clusters;
self
}
pub fn with_n_probe(mut self, n_probe: usize) -> Self {
self.ivf.n_probe = n_probe;
self
}
pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
self.ivf.metric = metric;
self
}
pub fn with_num_subspaces(mut self, num_subspaces: usize) -> Self {
self.pq.num_subspaces = num_subspaces;
self
}
pub fn with_num_centroids(mut self, num_centroids: usize) -> Self {
self.pq.num_centroids = num_centroids;
self
}
pub fn with_residuals(mut self, use_residuals: bool) -> Self {
self.use_residuals = use_residuals;
self
}
}
/// Inverted-file index whose vectors are stored as product-quantization codes.
///
/// Vectors are assigned to their nearest IVF centroid and stored as one `u8`
/// code per PQ subspace.
#[derive(Debug)]
pub struct IvfPqIndex {
/// Configuration the index was built with.
pub config: IvfPqConfig,
/// Flattened IVF centroids; centroid `c` occupies
/// `c * dimensions..(c + 1) * dimensions`.
pub ivf_centroids: Vec<f32>,
/// Cluster index -> ids of the vectors assigned to that cluster.
pub inverted_lists: HashMap<usize, Vec<u64>>,
/// Vector id -> PQ codes (one byte per subspace).
pub pq_codes: HashMap<u64, Vec<u8>>,
/// One flattened codebook per subspace; centroid `c` of a subspace occupies
/// `c * subspace_dims..(c + 1) * subspace_dims`.
pub pq_centroids: Vec<Vec<f32>>,
/// Row-major `n_clusters x n_clusters` matrix of pairwise IVF centroid
/// distances, filled in by `train`; `None` before training.
pub centroid_distances: Option<Vec<f32>>,
/// Number of components in every indexed vector.
pub dimensions: usize,
/// Components per PQ subspace (`dimensions / num_subspaces`).
pub subspace_dims: usize,
/// Whether the centroids and codebooks have been trained.
pub trained: bool,
// Flattened training vectors buffered until `train` runs; `None` once
// training has consumed them or when built via `from_serialized`.
training_vectors: Option<Vec<f32>>,
// Number of vectors currently buffered in `training_vectors`.
training_count: usize,
}
impl IvfPqIndex {
/// Creates an empty, untrained index for vectors with `dimensions` components.
///
/// PQ codebooks are allocated zero-filled; training data can be buffered
/// immediately with `add_training_vectors`.
///
/// # Errors
/// Returns [`IvfPqError::DimensionNotDivisible`] when `config.pq.num_subspaces`
/// is zero or does not divide `dimensions` evenly.
pub fn new(dimensions: usize, config: IvfPqConfig) -> Result<Self, IvfPqError> {
    let num_subspaces = config.pq.num_subspaces;
    // A zero subspace count would make the `%` below panic (division by zero),
    // so it is rejected with the same error as an uneven split.
    if num_subspaces == 0 || dimensions % num_subspaces != 0 {
        return Err(IvfPqError::DimensionNotDivisible {
            dimensions,
            num_subspaces,
        });
    }
    let subspace_dims = dimensions / num_subspaces;
    let pq_centroids: Vec<Vec<f32>> =
        vec![vec![0.0f32; config.pq.num_centroids * subspace_dims]; num_subspaces];
    Ok(Self {
        config,
        ivf_centroids: Vec::new(),
        inverted_lists: HashMap::new(),
        pq_codes: HashMap::new(),
        pq_centroids,
        centroid_distances: None,
        dimensions,
        subspace_dims,
        trained: false,
        training_vectors: Some(Vec::new()),
        training_count: 0,
    })
}
/// Creates an untrained index for `dimensions`-component vectors using the
/// default [`IvfPqConfig`].
///
/// # Errors
/// Propagates whatever `new` returns for this dimension count.
pub fn with_defaults(dimensions: usize) -> Result<Self, IvfPqError> {
    let config = IvfPqConfig::default();
    Self::new(dimensions, config)
}
/// Rebuilds an index from previously serialized parts.
///
/// The parts are stored as given; only the relation between `dimensions` and
/// `config.pq.num_subspaces` is validated. No training buffer is attached.
///
/// # Errors
/// Returns [`IvfPqError::DimensionNotDivisible`] when `config.pq.num_subspaces`
/// is zero or does not divide `dimensions` evenly.
// The argument list mirrors the serialized layout one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn from_serialized(
    config: IvfPqConfig,
    ivf_centroids: Vec<f32>,
    inverted_lists: HashMap<usize, Vec<u64>>,
    pq_codes: HashMap<u64, Vec<u8>>,
    pq_centroids: Vec<Vec<f32>>,
    centroid_distances: Option<Vec<f32>>,
    dimensions: usize,
    trained: bool,
) -> Result<Self, IvfPqError> {
    let num_subspaces = config.pq.num_subspaces;
    // Serialized headers are external input: a zero subspace count must become
    // an error instead of a division-by-zero panic.
    if num_subspaces == 0 || dimensions % num_subspaces != 0 {
        return Err(IvfPqError::DimensionNotDivisible {
            dimensions,
            num_subspaces,
        });
    }
    let subspace_dims = dimensions / num_subspaces;
    Ok(Self {
        config,
        ivf_centroids,
        inverted_lists,
        pq_codes,
        pq_centroids,
        centroid_distances,
        dimensions,
        subspace_dims,
        trained,
        training_vectors: None,
        training_count: 0,
    })
}
/// Buffers `count` training vectors taken from the front of `vectors`.
///
/// `vectors` is a flattened array; only the first `count * dimensions` floats
/// are copied.
///
/// # Errors
/// * [`IvfPqError::AlreadyTrained`] if the index has already been trained.
/// * [`IvfPqError::DimensionMismatch`] if `vectors` holds fewer than
///   `count * dimensions` floats (or that product does not fit in `usize`).
pub fn add_training_vectors(&mut self, vectors: &[f32], count: usize) -> Result<(), IvfPqError> {
    if self.trained {
        return Err(IvfPqError::AlreadyTrained);
    }
    // A wrapped product would pass the length check with a bogus slice length
    // while `training_count` still grew by `count`, so overflow is an error.
    let expected_len = match count.checked_mul(self.dimensions) {
        Some(len) if len <= vectors.len() => len,
        Some(len) => {
            return Err(IvfPqError::DimensionMismatch {
                expected: len,
                got: vectors.len(),
            });
        }
        None => {
            return Err(IvfPqError::DimensionMismatch {
                expected: usize::MAX,
                got: vectors.len(),
            });
        }
    };
    let training_buf = self.training_vectors.get_or_insert_with(Vec::new);
    training_buf.extend_from_slice(&vectors[..expected_len]);
    self.training_count += count;
    Ok(())
}
/// Trains the IVF centroids and PQ codebooks from the buffered training vectors.
///
/// Steps: k-means over the full vectors for the IVF centroids, PQ training on
/// either residuals or raw vectors (per `use_residuals`), the pairwise centroid
/// distance matrix, then one empty inverted list per cluster. Calling this on an
/// already trained index is a no-op.
///
/// The training buffer is only consumed once training can no longer fail on
/// validation or k-means, so after an error more vectors can be added and
/// `train` retried.
///
/// # Errors
/// * [`IvfPqError::NoTrainingVectors`] if there is no training buffer.
/// * [`IvfPqError::NotEnoughTrainingVectors`] if fewer vectors are buffered than
///   `n_clusters` or than `num_centroids`.
/// * [`IvfPqError::TrainingFailed`] if k-means reports an error.
pub fn train(&mut self) -> Result<(), IvfPqError> {
    if self.trained {
        return Ok(());
    }
    if self.training_vectors.is_none() {
        return Err(IvfPqError::NoTrainingVectors);
    }

    // Validate before taking the buffer: taking first would drop the data on
    // failure while `training_count` kept its old value.
    let n = self.training_count;
    let n_clusters = self.config.ivf.n_clusters;
    let num_centroids = self.config.pq.num_centroids;
    if n < n_clusters {
        return Err(IvfPqError::NotEnoughTrainingVectors { n, k: n_clusters });
    }
    if n < num_centroids {
        return Err(IvfPqError::NotEnoughTrainingVectors {
            n,
            k: num_centroids,
        });
    }

    let training_vectors = self
        .training_vectors
        .take()
        .ok_or(IvfPqError::NoTrainingVectors)?;
    let dims = self.dimensions;
    let distance_fn = self.config.ivf.metric.distance_fn();
    let kmeans_config = KMeansConfig::new(n_clusters)
        .with_max_iterations(25)
        .with_tolerance(1e-4);
    let kmeans_outcome = kmeans_parallel(&training_vectors, n, dims, &kmeans_config, distance_fn)
        .map_err(|e| IvfPqError::TrainingFailed(e.to_string()));
    let kmeans_result = match kmeans_outcome {
        Ok(result) => result,
        Err(err) => {
            // Hand the buffer back so the caller can retry.
            self.training_vectors = Some(training_vectors);
            return Err(err);
        }
    };
    self.ivf_centroids = kmeans_result.centroids;
    let assignments = kmeans_result.assignments;

    if self.config.use_residuals {
        // PQ is trained on `vector - assigned centroid`.
        let mut residuals = vec![0.0f32; n * dims];
        for (i, &cluster_id) in assignments.iter().enumerate().take(n) {
            let cluster = cluster_id as usize;
            let row = i * dims..(i + 1) * dims;
            let centroid = &self.ivf_centroids[cluster * dims..(cluster + 1) * dims];
            let source = &training_vectors[row.clone()];
            for ((out, &value), &center) in residuals[row].iter_mut().zip(source).zip(centroid) {
                *out = value - center;
            }
        }
        self.train_pq(&residuals, n)?;
    } else {
        self.train_pq(&training_vectors, n)?;
    }

    // Symmetric matrix: compute the upper triangle and mirror it.
    let mut centroid_distances = vec![0.0f32; n_clusters * n_clusters];
    for i in 0..n_clusters {
        let ci = &self.ivf_centroids[i * dims..(i + 1) * dims];
        for j in i..n_clusters {
            let cj = &self.ivf_centroids[j * dims..(j + 1) * dims];
            let dist = distance_fn(ci, cj);
            centroid_distances[i * n_clusters + j] = dist;
            centroid_distances[j * n_clusters + i] = dist;
        }
    }
    self.centroid_distances = Some(centroid_distances);

    for c in 0..n_clusters {
        self.inverted_lists.insert(c, Vec::new());
    }
    self.trained = true;
    self.training_vectors = None;
    self.training_count = 0;
    Ok(())
}
/// Trains one PQ codebook per subspace from `num_vectors` flattened vectors and
/// stores the results in `self.pq_centroids`.
///
/// Subspaces are trained in parallel with rayon except on `wasm32`, where they
/// are trained sequentially. Always returns `Ok(())`.
fn train_pq(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), IvfPqError> {
    let num_subspaces = self.config.pq.num_subspaces;
    let num_centroids = self.config.pq.num_centroids;
    let max_iterations = self.config.pq.max_iterations;
    let subspace_dims = self.subspace_dims;
    let stride = self.dimensions;

    // Gathers the m-th slice of every vector and runs k-means on it.
    let train_one = |m: usize| -> Vec<f32> {
        let start = m * subspace_dims;
        let mut gathered = Vec::with_capacity(num_vectors * subspace_dims);
        for row in 0..num_vectors {
            let from = row * stride + start;
            gathered.extend_from_slice(&vectors[from..from + subspace_dims]);
        }
        let mut codebook = vec![0.0f32; num_centroids * subspace_dims];
        train_pq_subspace(
            &mut codebook,
            &gathered,
            num_vectors,
            subspace_dims,
            num_centroids,
            max_iterations,
        );
        codebook
    };

    #[cfg(not(target_arch = "wasm32"))]
    let codebooks: Vec<Vec<f32>> = (0..num_subspaces)
        .into_par_iter()
        .map(train_one)
        .collect();
    #[cfg(target_arch = "wasm32")]
    let codebooks: Vec<Vec<f32>> = (0..num_subspaces).map(train_one).collect();

    for (m, codebook) in codebooks.into_iter().enumerate() {
        self.pq_centroids[m] = codebook;
    }
    Ok(())
}
/// Adds (or replaces) the vector stored under `vector_id`.
///
/// For the cosine metric the vector is normalized first. It is assigned to the
/// nearest IVF centroid and encoded with PQ — on its residual from that centroid
/// when `use_residuals` is set, otherwise on the vector itself.
///
/// Inserting an id that is already present replaces it: the old posting is
/// removed from the inverted lists so the id never appears twice.
///
/// # Errors
/// * [`IvfPqError::NotTrained`] if the index has not been trained.
/// * [`IvfPqError::DimensionMismatch`] if `vector.len() != dimensions`.
pub fn insert(&mut self, vector_id: u64, vector: &[f32]) -> Result<(), IvfPqError> {
    if !self.trained {
        return Err(IvfPqError::NotTrained);
    }
    if vector.len() != self.dimensions {
        return Err(IvfPqError::DimensionMismatch {
            expected: self.dimensions,
            got: vector.len(),
        });
    }
    let dims = self.dimensions;
    let distance_fn = self.config.ivf.metric.distance_fn();
    let query_vec: Cow<[f32]> = if self.config.ivf.metric == DistanceMetric::Cosine {
        Cow::Owned(normalize(vector))
    } else {
        Cow::Borrowed(vector)
    };
    let query_slice = query_vec.as_ref();

    // Nearest coarse centroid; the first one wins on ties.
    let mut best_cluster = 0;
    let mut best_dist = f32::INFINITY;
    for c in 0..self.config.ivf.n_clusters {
        let centroid = &self.ivf_centroids[c * dims..(c + 1) * dims];
        let dist = distance_fn(query_slice, centroid);
        if dist < best_dist {
            best_dist = dist;
            best_cluster = c;
        }
    }

    let codes = if self.config.use_residuals {
        let centroid = &self.ivf_centroids[best_cluster * dims..(best_cluster + 1) * dims];
        let residual: Vec<f32> = query_slice
            .iter()
            .zip(centroid)
            .map(|(v, c)| v - c)
            .collect();
        self.encode_single_vector(&residual)
    } else {
        self.encode_single_vector(query_slice)
    };

    // Re-insertion: drop the previous posting, which may sit in another cluster,
    // otherwise searches would report the id more than once.
    if self.pq_codes.contains_key(&vector_id) {
        for list in self.inverted_lists.values_mut() {
            list.retain(|&id| id != vector_id);
        }
    }

    self.inverted_lists
        .entry(best_cluster)
        .or_default()
        .push(vector_id);
    self.pq_codes.insert(vector_id, codes);
    Ok(())
}
/// Encodes `vector` as one `u8` code per subspace: the index of the codebook
/// centroid with the smallest squared Euclidean distance to that sub-vector.
///
/// Codes are bytes, so only centroid indices `0..=255` can be stored. When
/// `num_centroids` exceeds 256 the search is restricted to the first 256
/// centroids; casting a larger winning index to `u8` would wrap and point at an
/// unrelated centroid.
fn encode_single_vector(&self, vector: &[f32]) -> Vec<u8> {
    let sub_dims = self.subspace_dims;
    let addressable = self
        .config
        .pq
        .num_centroids
        .min(usize::from(u8::MAX) + 1);

    (0..self.config.pq.num_subspaces)
        .map(|m| {
            let subvec = &vector[m * sub_dims..(m + 1) * sub_dims];
            let mut best_code = 0u8;
            let mut best_dist = f32::INFINITY;
            for c in 0..addressable {
                let centroid = &self.pq_centroids[m][c * sub_dims..(c + 1) * sub_dims];
                let dist = subvec.iter().zip(centroid).fold(0.0f32, |acc, (x, y)| {
                    let diff = x - y;
                    acc + diff * diff
                });
                // Strict `<` keeps the first centroid on ties.
                if dist < best_dist {
                    best_dist = dist;
                    // `c < 256` by construction, so the cast is lossless.
                    best_code = c as u8;
                }
            }
            best_code
        })
        .collect()
}
/// Removes `vector_id` from the index.
///
/// `vector` is used to predict which inverted list holds the id (nearest IVF
/// centroid, after normalization for the cosine metric). If the id is not in
/// that list but the index does hold codes for it, every list is scanned, so a
/// caller passing a slightly different vector cannot leave a stale posting.
/// A `vector` of the wrong length skips the prediction.
///
/// Returns `true` if a posting or the PQ codes were removed, `false` otherwise
/// (including when the index is untrained).
pub fn delete(&mut self, vector_id: u64, vector: &[f32]) -> bool {
    if !self.trained {
        return false;
    }
    let dims = self.dimensions;
    let mut removed_from_list = false;

    if vector.len() == dims {
        let distance_fn = self.config.ivf.metric.distance_fn();
        let query_vec: Cow<[f32]> = if self.config.ivf.metric == DistanceMetric::Cosine {
            Cow::Owned(normalize(vector))
        } else {
            Cow::Borrowed(vector)
        };
        let query_slice = query_vec.as_ref();

        let mut best_cluster = 0;
        let mut best_dist = f32::INFINITY;
        for c in 0..self.config.ivf.n_clusters {
            let centroid = &self.ivf_centroids[c * dims..(c + 1) * dims];
            let dist = distance_fn(query_slice, centroid);
            if dist < best_dist {
                best_dist = dist;
                best_cluster = c;
            }
        }

        if let Some(list) = self.inverted_lists.get_mut(&best_cluster) {
            if let Some(idx) = list.iter().position(|&id| id == vector_id) {
                list.swap_remove(idx);
                removed_from_list = true;
            }
        }
    }

    let removed_codes = self.pq_codes.remove(&vector_id).is_some();

    // The id was known but not where the supplied vector pointed: look
    // everywhere. Unknown ids skip this scan and stay cheap.
    if removed_codes && !removed_from_list {
        for list in self.inverted_lists.values_mut() {
            if let Some(idx) = list.iter().position(|&id| id == vector_id) {
                list.swap_remove(idx);
                removed_from_list = true;
                break;
            }
        }
    }

    removed_from_list || removed_codes
}
/// Returns up to `k` approximate nearest neighbours of `query`, closest first.
///
/// The `n_probe` nearest IVF clusters are scanned (`options.n_probe` overrides
/// the configured value). Distances come from a per-query table of squared
/// distances to the PQ centroids; with residual encoding that table is rebuilt
/// per probed cluster from `query - centroid`.
///
/// `options.filter` is applied to the node id of each candidate that has an
/// entry in `manifest.vector_to_node`; candidates without an entry are not
/// filtered and are reported with node id `0`. `options.threshold` drops
/// candidates whose similarity is below it.
///
/// Returns an empty vector when the index is untrained, `k` is zero, or
/// `query.len()` differs from the index dimensionality.
pub fn search(
    &self,
    manifest: &VectorManifest,
    query: &[f32],
    k: usize,
    options: Option<IvfPqSearchOptions>,
) -> Vec<VectorSearchResult> {
    // A query of the wrong length would panic while slicing sub-vectors (too
    // short) or be silently truncated (too long); treat it as "no results".
    if !self.trained || k == 0 || query.len() != self.dimensions {
        return Vec::new();
    }
    let options = options.unwrap_or_default();
    let n_probe = options.n_probe.unwrap_or(self.config.ivf.n_probe);
    let query_for_search: Cow<[f32]> = if self.config.ivf.metric == DistanceMetric::Cosine {
        Cow::Owned(normalize(query))
    } else {
        Cow::Borrowed(query)
    };
    let query_slice = query_for_search.as_ref();
    let probe_clusters = self.find_nearest_centroids(query_slice, n_probe);

    // Without residuals the table does not depend on the cluster: build it once.
    let shared_table: Option<Vec<f32>> = if self.config.use_residuals {
        None
    } else {
        Some(self.build_distance_table(query_slice))
    };

    let mut heap = MaxHeap::new();
    let mut scan_list = |dist_table: &[f32], vector_ids: &[u64]| {
        for &vector_id in vector_ids {
            if let Some(ref filter) = options.filter {
                if let Some(&node_id) = manifest.vector_to_node.get(&vector_id) {
                    if !filter(node_id) {
                        continue;
                    }
                }
            }
            let codes = match self.pq_codes.get(&vector_id) {
                Some(c) => c,
                None => continue,
            };
            let dist = self.distance_adc(dist_table, codes);
            if let Some(threshold) = options.threshold {
                let similarity = self.config.ivf.metric.distance_to_similarity(dist);
                if similarity < threshold {
                    continue;
                }
            }
            // Keep the k smallest distances: the heap root is the current worst.
            if heap.len() < k {
                heap.push(vector_id, dist);
            } else if let Some(&(_, max_dist)) = heap.peek() {
                if dist < max_dist {
                    heap.pop();
                    heap.push(vector_id, dist);
                }
            }
        }
    };

    for cluster in probe_clusters {
        let vector_ids = match self.inverted_lists.get(&cluster) {
            Some(list) if !list.is_empty() => list.as_slice(),
            _ => continue,
        };
        match shared_table.as_deref() {
            Some(table) => scan_list(table, vector_ids),
            None => {
                let cent_offset = cluster * self.dimensions;
                let query_residual: Vec<f32> = query_slice
                    .iter()
                    .zip(&self.ivf_centroids[cent_offset..cent_offset + self.dimensions])
                    .map(|(q, c)| q - c)
                    .collect();
                let table = self.build_distance_table(&query_residual);
                scan_list(&table, vector_ids);
            }
        }
    }

    heap.into_sorted_vec()
        .into_iter()
        .map(|(vector_id, distance)| {
            let node_id = manifest
                .vector_to_node
                .get(&vector_id)
                .copied()
                .unwrap_or(0);
            VectorSearchResult {
                vector_id,
                node_id,
                distance,
                similarity: self.config.ivf.metric.distance_to_similarity(distance),
            }
        })
        .collect()
}
/// Builds the lookup table used for asymmetric distance computation.
///
/// Entry `m * num_centroids + c` holds the squared Euclidean distance between
/// the `m`-th sub-vector of `query` and centroid `c` of subspace `m`.
fn build_distance_table(&self, query: &[f32]) -> Vec<f32> {
    let sub_dims = self.subspace_dims;
    let num_subspaces = self.config.pq.num_subspaces;
    let num_centroids = self.config.pq.num_centroids;

    let mut table: Vec<f32> = Vec::with_capacity(num_subspaces * num_centroids);
    for m in 0..num_subspaces {
        let query_sub = &query[m * sub_dims..(m + 1) * sub_dims];
        table.extend((0..num_centroids).map(|c| {
            let centroid = &self.pq_centroids[m][c * sub_dims..(c + 1) * sub_dims];
            query_sub
                .iter()
                .zip(centroid)
                .fold(0.0f32, |acc, (q, x)| {
                    let diff = q - x;
                    acc + diff * diff
                })
        }));
    }
    table
}
/// Sums, over all subspaces, the table entry selected by that subspace's code.
///
/// Subspaces are processed in groups of eight; each group is summed on its own
/// before being added to the running total, and leftovers are added one by one.
fn distance_adc(&self, table: &[f32], codes: &[u8]) -> f32 {
    const GROUP: usize = 8;
    let num_subspaces = self.config.pq.num_subspaces;
    let num_centroids = self.config.pq.num_centroids;
    let lookup = |m: usize| table[m * num_centroids + codes[m] as usize];

    let grouped_end = num_subspaces - num_subspaces % GROUP;
    let mut total = 0.0f32;
    let mut m = 0;
    while m < grouped_end {
        // Left-to-right partial sum of the eight entries of this group.
        let mut partial = lookup(m);
        for lane in 1..GROUP {
            partial += lookup(m + lane);
        }
        total += partial;
        m += GROUP;
    }
    for tail in grouped_end..num_subspaces {
        total += lookup(tail);
    }
    total
}
/// Returns the indices of the `n` IVF centroids closest to `query`, nearest
/// first (all centroids when `n >= n_clusters`, none when `n == 0`).
///
/// Distances that cannot be ordered (e.g. NaN) compare as equal.
fn find_nearest_centroids(&self, query: &[f32], n: usize) -> Vec<usize> {
    // Orders `(centroid, distance)` pairs by distance.
    fn by_distance(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
    }

    let distance_fn = self.config.ivf.metric.distance_fn();
    let n_clusters = self.config.ivf.n_clusters;
    if n == 0 {
        return Vec::new();
    }

    let dims = self.dimensions;
    let mut ranked: Vec<(usize, f32)> = Vec::with_capacity(n_clusters);
    for c in 0..n_clusters {
        let centroid = &self.ivf_centroids[c * dims..(c + 1) * dims];
        ranked.push((c, distance_fn(query, centroid)));
    }

    if n < n_clusters {
        // Partition so the n closest come first, then discard the rest.
        ranked.select_nth_unstable_by(n - 1, by_distance);
        ranked.truncate(n);
    }
    ranked.sort_by(by_distance);
    ranked.into_iter().map(|(c, _)| c).collect()
}
pub fn search_multi(
&self,
manifest: &VectorManifest,
queries: &[&[f32]],
k: usize,
aggregation: MultiQueryAggregation,
options: Option<IvfPqSearchOptions>,
) -> Vec<VectorSearchResult> {
if !self.trained || queries.is_empty() {
return Vec::new();
}
let options = options.unwrap_or_default();
let expanded_k = k * 2;
let all_results: Vec<Vec<VectorSearchResult>> = queries
.iter()
.map(|query| self.search(manifest, query, expanded_k, None))
.collect();
let mut aggregated: std::collections::HashMap<NodeId, (Vec<f32>, u64)> =
std::collections::HashMap::new();
for results in &all_results {
for result in results {
let entry = aggregated
.entry(result.node_id)
.or_insert_with(|| (Vec::new(), result.vector_id));
entry.0.push(result.distance);
}
}
let aggregated: std::collections::HashMap<NodeId, (Vec<f32>, u64)> =
if let Some(ref filter) = options.filter {
aggregated
.into_iter()
.filter(|(node_id, _)| filter(*node_id))
.collect()
} else {
aggregated
};
let mut scored: Vec<VectorSearchResult> = aggregated
.into_iter()
.map(|(node_id, (distances, vector_id))| {
let distance = aggregation.aggregate(&distances);
let similarity = self.config.ivf.metric.distance_to_similarity(distance);
VectorSearchResult {
vector_id,
node_id,
distance,
similarity,
}
})
.collect();
if let Some(threshold) = options.threshold {
scored.retain(|r| r.similarity >= threshold);
}
scored.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
scored.truncate(k);
scored
}
/// Trains the index on every row group in `manifest` and then inserts every
/// vector listed in `manifest.vector_locations`.
///
/// Locations whose fragment or row group cannot be found, or whose row is
/// marked deleted in its fragment, are skipped.
///
/// # Errors
/// Propagates errors from `add_training_vectors`, `train` and `insert`.
pub fn build_from_store(&mut self, manifest: &VectorManifest) -> Result<(), IvfPqError> {
    let all_row_groups = manifest
        .fragments
        .iter()
        .flat_map(|fragment| fragment.row_groups.iter());
    for row_group in all_row_groups {
        self.add_training_vectors(&row_group.data, row_group.count)?;
    }
    self.train()?;

    let dims = manifest.config.dimensions;
    let rows_per_group = manifest.config.row_group_size;
    let fragment_map: std::collections::HashMap<usize, &_> =
        manifest.fragments.iter().map(|f| (f.id, f)).collect();

    for (&vector_id, location) in &manifest.vector_locations {
        let fragment = if let Some(found) = fragment_map.get(&location.fragment_id) {
            *found
        } else {
            continue;
        };
        if fragment.is_deleted(location.local_index) {
            continue;
        }
        // `local_index` addresses a row within the fragment: split it into the
        // row group and the row inside that group.
        let group_idx = location.local_index / rows_per_group;
        let row_in_group = location.local_index % rows_per_group;
        if let Some(row_group) = fragment.row_groups.get(group_idx) {
            let start = row_in_group * dims;
            self.insert(vector_id, &row_group.data[start..start + dims])?;
        }
    }
    Ok(())
}
/// Summarises cluster occupancy and the estimated memory savings.
///
/// `memory_savings_ratio` is the size of the raw `f32` vectors divided by the
/// size of the PQ codes plus both centroid tables; it is `0.0` while the index
/// holds no vectors.
pub fn stats(&self) -> IvfPqStats {
    let list_sizes = || self.inverted_lists.values().map(Vec::len);

    let total_vectors: usize = list_sizes().sum();
    let empty_cluster_count = list_sizes().filter(|&len| len == 0).count();
    // Both extremes are 0 when there are no inverted lists at all.
    let min_cluster_size = list_sizes().min().unwrap_or(0);
    let max_cluster_size = list_sizes().max().unwrap_or(0);

    let n_clusters = self.config.ivf.n_clusters;
    let num_subspaces = self.config.pq.num_subspaces;
    let num_centroids = self.config.pq.num_centroids;

    // 4 bytes per f32 component, 1 byte per PQ code.
    let original_bytes = total_vectors * self.dimensions * 4;
    let pq_code_bytes = total_vectors * num_subspaces;
    let pq_centroid_bytes = num_subspaces * num_centroids * self.subspace_dims * 4;
    let ivf_centroid_bytes = n_clusters * self.dimensions * 4;
    let compressed_bytes = pq_code_bytes + pq_centroid_bytes + ivf_centroid_bytes;

    let memory_savings_ratio = match original_bytes {
        0 => 0.0,
        raw => raw as f32 / compressed_bytes as f32,
    };
    let avg_vectors_per_cluster = match n_clusters {
        0 => 0.0,
        clusters => total_vectors as f32 / clusters as f32,
    };

    IvfPqStats {
        trained: self.trained,
        n_clusters,
        total_vectors,
        avg_vectors_per_cluster,
        empty_cluster_count,
        min_cluster_size,
        max_cluster_size,
        pq_num_subspaces: num_subspaces,
        pq_num_centroids: num_centroids,
        memory_savings_ratio,
    }
}
/// Resets the index to its untrained state.
///
/// Drops all centroids, inverted lists and codes, zeroes the PQ codebooks in
/// place, and attaches a fresh, empty training buffer.
pub fn clear(&mut self) {
    self.trained = false;
    self.training_count = 0;
    self.training_vectors = Some(Vec::new());
    self.centroid_distances = None;

    self.ivf_centroids.clear();
    self.inverted_lists.clear();
    self.pq_codes.clear();
    self.pq_centroids
        .iter_mut()
        .for_each(|codebook| codebook.fill(0.0));
}
}
/// Optional per-search overrides; the default leaves every field `None`.
#[derive(Default)]
pub struct IvfPqSearchOptions {
/// Number of clusters to probe; falls back to the index configuration when `None`.
pub n_probe: Option<usize>,
/// Predicate on node ids; results whose node id it rejects are dropped.
pub filter: Option<Box<dyn Fn(NodeId) -> bool>>,
/// Minimum similarity a result must reach to be returned.
pub threshold: Option<f32>,
}
impl std::fmt::Debug for IvfPqSearchOptions {
    /// Formats the options; the filter closure is shown as the placeholder
    /// `"<fn>"` because closures are not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let filter_marker: Option<&str> = match self.filter {
            Some(_) => Some("<fn>"),
            None => None,
        };
        let mut out = f.debug_struct("IvfPqSearchOptions");
        out.field("n_probe", &self.n_probe);
        out.field("filter", &filter_marker);
        out.field("threshold", &self.threshold);
        out.finish()
    }
}
/// Snapshot of index occupancy and compression, as produced by `IvfPqIndex::stats`.
#[derive(Debug, Clone)]
pub struct IvfPqStats {
/// Whether the index has been trained.
pub trained: bool,
/// Configured number of IVF clusters.
pub n_clusters: usize,
/// Total number of ids across all inverted lists.
pub total_vectors: usize,
/// `total_vectors / n_clusters` (0.0 when `n_clusters` is 0).
pub avg_vectors_per_cluster: f32,
/// Number of inverted lists that are empty.
pub empty_cluster_count: usize,
/// Length of the shortest inverted list (0 when there are no lists).
pub min_cluster_size: usize,
/// Length of the longest inverted list.
pub max_cluster_size: usize,
/// Configured number of PQ subspaces.
pub pq_num_subspaces: usize,
/// Configured number of PQ centroids per subspace.
pub pq_num_centroids: usize,
/// Raw `f32` vector bytes divided by PQ-code plus centroid-table bytes
/// (0.0 when the index holds no vectors).
pub memory_savings_ratio: f32,
}
/// Binary max-heap of `(id, distance)` pairs ordered by distance.
///
/// The root is always the entry with the largest distance, which lets a search
/// keep its k best candidates by evicting the current worst.
struct MaxHeap {
    /// Implicit binary tree: the children of slot `i` live at `2i + 1` and `2i + 2`.
    items: Vec<(u64, f32)>,
}

impl MaxHeap {
    /// Creates an empty heap.
    fn new() -> Self {
        MaxHeap { items: Vec::new() }
    }

    /// Number of entries currently stored.
    fn len(&self) -> usize {
        self.items.len()
    }

    /// Inserts an entry and restores the heap order.
    fn push(&mut self, id: u64, dist: f32) {
        let slot = self.items.len();
        self.items.push((id, dist));
        self.sift_up(slot);
    }

    /// Removes and returns the entry with the largest distance, if any.
    fn pop(&mut self) -> Option<(u64, f32)> {
        let last = self.items.len().checked_sub(1)?;
        // Move the root to the end, detach it, then repair from the top.
        self.items.swap(0, last);
        let top = self.items.pop();
        if last > 0 {
            self.sift_down(0);
        }
        top
    }

    /// Borrows the entry with the largest distance without removing it.
    fn peek(&self) -> Option<&(u64, f32)> {
        self.items.get(0)
    }

    /// Moves the entry at `idx` towards the root while it exceeds its parent.
    fn sift_up(&mut self, mut idx: usize) {
        while idx > 0 {
            let parent = (idx - 1) / 2;
            let rises = self.items[idx].1 > self.items[parent].1;
            if !rises {
                return;
            }
            self.items.swap(idx, parent);
            idx = parent;
        }
    }

    /// Moves the entry at `idx` towards the leaves while a child exceeds it.
    fn sift_down(&mut self, mut idx: usize) {
        let len = self.items.len();
        loop {
            let left = 2 * idx + 1;
            let right = left + 1;
            let mut largest = idx;
            if left < len && self.items[left].1 > self.items[largest].1 {
                largest = left;
            }
            if right < len && self.items[right].1 > self.items[largest].1 {
                largest = right;
            }
            if largest == idx {
                return;
            }
            self.items.swap(idx, largest);
            idx = largest;
        }
    }

    /// Consumes the heap and returns its entries by ascending distance.
    fn into_sorted_vec(mut self) -> Vec<(u64, f32)> {
        let mut descending = Vec::with_capacity(self.items.len());
        while let Some(entry) = self.pop() {
            descending.push(entry);
        }
        descending.into_iter().rev().collect()
    }
}
/// Runs k-means on the sub-vectors of one PQ subspace, writing the resulting
/// `num_centroids` centroids into `centroids`.
///
/// `subvectors` holds `num_vectors` points of `subspace_dims` floats each and
/// `centroids` must have room for `num_centroids * subspace_dims` floats.
/// Centroids are seeded with `initialize_pq_centroids_kmeans_pp`, then exactly
/// `max_iterations` assign/update rounds are run using squared Euclidean
/// distance. A centroid that receives no points keeps its previous position.
fn train_pq_subspace(
    centroids: &mut [f32],
    subvectors: &[f32],
    num_vectors: usize,
    subspace_dims: usize,
    num_centroids: usize,
    max_iterations: usize,
) {
    initialize_pq_centroids_kmeans_pp(
        centroids,
        subvectors,
        num_vectors,
        subspace_dims,
        num_centroids,
    );

    let dims = subspace_dims;
    let mut assignments = vec![0u16; num_vectors];
    let mut cluster_sums = vec![0.0f32; num_centroids * dims];
    let mut cluster_counts = vec![0u32; num_centroids];

    for _round in 0..max_iterations {
        // Assignment step: nearest centroid per point, first one wins on ties.
        for (i, slot) in assignments.iter_mut().enumerate() {
            let point = &subvectors[i * dims..(i + 1) * dims];
            let mut nearest = 0usize;
            let mut nearest_dist = f32::INFINITY;
            for c in 0..num_centroids {
                let centroid = &centroids[c * dims..(c + 1) * dims];
                let dist = point.iter().zip(centroid).fold(0.0f32, |acc, (x, y)| {
                    let diff = x - y;
                    acc + diff * diff
                });
                if dist < nearest_dist {
                    nearest_dist = dist;
                    nearest = c;
                }
            }
            *slot = nearest as u16;
        }

        // Update step: accumulate per-cluster sums and counts.
        cluster_sums.fill(0.0);
        cluster_counts.fill(0);
        for (i, &cluster_id) in assignments.iter().enumerate() {
            let cluster = cluster_id as usize;
            let point = &subvectors[i * dims..(i + 1) * dims];
            let sums = &mut cluster_sums[cluster * dims..(cluster + 1) * dims];
            for (sum, &value) in sums.iter_mut().zip(point) {
                *sum += value;
            }
            cluster_counts[cluster] += 1;
        }

        // Move every non-empty cluster to the mean of its members.
        for (c, &count) in cluster_counts.iter().enumerate() {
            if count != 0 {
                let span = c * dims..(c + 1) * dims;
                let means = centroids[span.clone()].iter_mut();
                for (mean, &sum) in means.zip(&cluster_sums[span]) {
                    *mean = sum / count as f32;
                }
            }
        }
    }
}
/// Seeds `k` centroids with k-means++ sampling.
///
/// The first centroid is a uniformly random input vector. Each further centroid
/// is drawn with probability proportional to a vector's squared distance to its
/// nearest already chosen centroid.
///
/// `vectors` holds `num_vectors` points of `dims` floats each; `centroids` must
/// have room for `k * dims` floats. Nothing is written when `k` or
/// `num_vectors` is zero.
fn initialize_pq_centroids_kmeans_pp(
    centroids: &mut [f32],
    vectors: &[f32],
    num_vectors: usize,
    dims: usize,
    k: usize,
) {
    use rand::Rng;

    // Sampling from an empty range panics, and with `k == 0` there is no slot
    // for even the first centroid.
    if k == 0 || num_vectors == 0 {
        return;
    }

    let mut rng = rand::thread_rng();
    let first_idx = rng.gen_range(0..num_vectors);
    centroids[..dims].copy_from_slice(&vectors[first_idx * dims..(first_idx + 1) * dims]);

    // Squared distance from each vector to its nearest chosen centroid so far.
    let mut min_dists = vec![f32::INFINITY; num_vectors];
    for c in 1..k {
        // Only the most recently added centroid can lower the minima.
        let prev = (c - 1) * dims..c * dims;
        let mut total_dist = 0.0;
        for (i, min_dist) in min_dists.iter_mut().enumerate() {
            let point = &vectors[i * dims..(i + 1) * dims];
            let dist = point
                .iter()
                .zip(&centroids[prev.clone()])
                .fold(0.0f32, |acc, (x, y)| {
                    let diff = x - y;
                    acc + diff * diff
                });
            *min_dist = (*min_dist).min(dist);
            total_dist += *min_dist;
        }

        // Walk the cumulative distribution until the random mass is used up.
        // If rounding leaves a sliver of mass after the last vector, the draw
        // belongs to the tail, so the fallback is the last index rather than
        // index 0.
        let mut r = rng.gen::<f32>() * total_dist;
        let mut selected_idx = num_vectors - 1;
        for (i, dist) in min_dists.iter().enumerate() {
            r -= *dist;
            if r <= 0.0 {
                selected_idx = i;
                break;
            }
        }

        centroids[c * dims..(c + 1) * dims]
            .copy_from_slice(&vectors[selected_idx * dims..(selected_idx + 1) * dims]);
    }
}
/// Errors reported while building, training or updating an IVF-PQ index.
#[derive(Debug, Clone)]
pub enum IvfPqError {
    /// The vector dimensionality cannot be split evenly into the PQ subspaces.
    DimensionNotDivisible {
        dimensions: usize,
        num_subspaces: usize,
    },
    /// An input had a different length than required.
    DimensionMismatch {
        expected: usize,
        got: usize,
    },
    /// Training data was added after training finished.
    AlreadyTrained,
    /// The operation needs a trained index.
    NotTrained,
    /// Training was requested without a training buffer.
    NoTrainingVectors,
    /// Fewer training vectors (`n`) than required centroids (`k`).
    NotEnoughTrainingVectors {
        n: usize,
        k: usize,
    },
    /// The clustering step failed; the payload is its error message.
    TrainingFailed(String),
}

impl std::fmt::Display for IvfPqError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::AlreadyTrained => f.write_str("Index already trained"),
            Self::NotTrained => f.write_str("Index not trained"),
            Self::NoTrainingVectors => f.write_str("No training vectors provided"),
            Self::TrainingFailed(msg) => write!(f, "Training failed: {msg}"),
            Self::NotEnoughTrainingVectors { n, k } => {
                write!(f, "Not enough training vectors: {n} < {k} required")
            }
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
            Self::DimensionNotDivisible {
                dimensions,
                num_subspaces,
            } => write!(
                f,
                "Dimensions ({dimensions}) must be divisible by num_subspaces ({num_subspaces})"
            ),
        }
    }
}

impl std::error::Error for IvfPqError {}
/// Magic number written as the first little-endian `u32` of a serialized index;
/// its hex digits spell the ASCII letters "IVPQ".
const IVFPQ_MAGIC: u32 = 0x49565051;
/// Size in bytes of the fixed header: seven `u32` fields, three single-byte
/// fields, and 17 reserved zero bytes.
const IVFPQ_HEADER_SIZE: usize = 48;
/// Errors reported while decoding a serialized IVF-PQ index.
#[derive(Debug, Clone)]
pub enum SerializeError {
    /// The buffer does not start with the expected magic number.
    InvalidMagic { expected: u32, got: u32 },
    /// The buffer ended before a field could be read completely.
    BufferUnderflow {
        context: String,
        offset: usize,
        needed: usize,
        available: usize,
    },
    /// The stored metric code is not one of the known values.
    InvalidMetric(u32),
}

impl std::fmt::Display for SerializeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidMetric(n) => write!(
                f,
                "Invalid metric value: {n}. Expected 0 (cosine), 1 (euclidean), or 2 (dot)"
            ),
            Self::InvalidMagic { expected, got } => write!(
                f,
                "Invalid magic: expected 0x{expected:08X}, got 0x{got:08X}"
            ),
            Self::BufferUnderflow {
                context,
                offset,
                needed,
                available,
            } => write!(
                f,
                "Buffer underflow in {context}: need {needed} bytes at offset {offset}, but only {available} available"
            ),
        }
    }
}

impl std::error::Error for SerializeError {}
fn metric_to_u8(metric: DistanceMetric) -> u8 {
match metric {
DistanceMetric::Cosine => 0,
DistanceMetric::Euclidean => 1,
DistanceMetric::DotProduct => 2,
}
}
/// Decodes the metric byte of a serialized header.
///
/// # Errors
/// Returns [`SerializeError::InvalidMetric`] for any value other than 0, 1 or 2.
fn u8_to_metric(n: u8) -> Result<DistanceMetric, SerializeError> {
    let metric = match n {
        0 => DistanceMetric::Cosine,
        1 => DistanceMetric::Euclidean,
        2 => DistanceMetric::DotProduct,
        unknown => return Err(SerializeError::InvalidMetric(u32::from(unknown))),
    };
    Ok(metric)
}
/// Checks that `needed` bytes are available at `offset` in a buffer of
/// `buf_len` bytes.
///
/// `context` names the field being read and is copied into the error.
///
/// # Errors
/// Returns [`SerializeError::BufferUnderflow`] when the range
/// `offset..offset + needed` does not fit inside the buffer, including when
/// that sum would not fit in `usize`.
fn ensure_bytes(
    buf_len: usize,
    offset: usize,
    needed: usize,
    context: &str,
) -> Result<(), SerializeError> {
    // Compare against the remaining length instead of computing
    // `offset + needed`: the sizes come from untrusted input and that sum can
    // overflow (easily so on 32-bit targets), which would wrongly pass the check.
    if offset <= buf_len && needed <= buf_len - offset {
        return Ok(());
    }
    Err(SerializeError::BufferUnderflow {
        context: context.to_string(),
        offset,
        needed,
        available: buf_len.saturating_sub(offset),
    })
}
/// Copies the `N` bytes at `*offset` out of `buffer` and advances `*offset`
/// past them.
///
/// Shared by the typed readers below so the bounds check and the error
/// construction exist exactly once. On failure `*offset` is left untouched.
///
/// # Errors
/// Returns [`SerializeError::BufferUnderflow`] (tagged with `context`) when
/// fewer than `N` bytes remain at `*offset`.
fn read_fixed_bytes<const N: usize>(
    buffer: &[u8],
    offset: &mut usize,
    context: &str,
) -> Result<[u8; N], SerializeError> {
    let start = *offset;
    // `checked_add` turns an overflowing end position into "out of range".
    let slice = start
        .checked_add(N)
        .and_then(|end| buffer.get(start..end))
        .ok_or_else(|| SerializeError::BufferUnderflow {
            context: context.to_string(),
            offset: start,
            needed: N,
            available: buffer.len().saturating_sub(start),
        })?;
    let mut bytes = [0u8; N];
    bytes.copy_from_slice(slice);
    *offset = start + N;
    Ok(bytes)
}

/// Reads one byte at `*offset` and advances the offset by 1.
///
/// # Errors
/// Returns [`SerializeError::BufferUnderflow`] if no byte is left.
fn read_u8(buffer: &[u8], offset: &mut usize, context: &str) -> Result<u8, SerializeError> {
    let [byte] = read_fixed_bytes::<1>(buffer, offset, context)?;
    Ok(byte)
}

/// Reads a little-endian `u32` at `*offset` and advances the offset by 4.
///
/// # Errors
/// Returns [`SerializeError::BufferUnderflow`] if fewer than 4 bytes are left.
fn read_u32_le(buffer: &[u8], offset: &mut usize, context: &str) -> Result<u32, SerializeError> {
    read_fixed_bytes::<4>(buffer, offset, context).map(u32::from_le_bytes)
}

/// Reads a little-endian `u64` at `*offset` and advances the offset by 8.
///
/// # Errors
/// Returns [`SerializeError::BufferUnderflow`] if fewer than 8 bytes are left.
fn read_u64_le(buffer: &[u8], offset: &mut usize, context: &str) -> Result<u64, SerializeError> {
    read_fixed_bytes::<8>(buffer, offset, context).map(u64::from_le_bytes)
}

/// Reads a little-endian `f32` at `*offset` and advances the offset by 4.
///
/// # Errors
/// Returns [`SerializeError::BufferUnderflow`] if fewer than 4 bytes are left.
fn read_f32_le(buffer: &[u8], offset: &mut usize, context: &str) -> Result<f32, SerializeError> {
    read_fixed_bytes::<4>(buffer, offset, context).map(f32::from_le_bytes)
}
pub fn ivf_pq_serialized_size(index: &IvfPqIndex) -> usize {
let mut size = IVFPQ_HEADER_SIZE;
size += 4 + index.ivf_centroids.len() * 4;
size += 4;
for list in index.inverted_lists.values() {
size += 4 + 4 + list.len() * 8; }
size += 4; for centroids in &index.pq_centroids {
size += 4 + centroids.len() * 4; }
size += 4; for codes in index.pq_codes.values() {
size += 8 + 4 + codes.len(); }
size += 1; if let Some(ref dists) = index.centroid_distances {
size += 4 + dists.len() * 4; }
size
}
pub fn serialize_ivf_pq(index: &IvfPqIndex) -> Vec<u8> {
let size = ivf_pq_serialized_size(index);
let mut buffer = Vec::with_capacity(size);
buffer.extend_from_slice(&IVFPQ_MAGIC.to_le_bytes());
buffer.extend_from_slice(&(index.dimensions as u32).to_le_bytes());
buffer.extend_from_slice(&(index.config.ivf.n_clusters as u32).to_le_bytes());
buffer.extend_from_slice(&(index.config.ivf.n_probe as u32).to_le_bytes());
buffer.extend_from_slice(&(index.config.pq.num_subspaces as u32).to_le_bytes());
buffer.extend_from_slice(&(index.config.pq.num_centroids as u32).to_le_bytes());
buffer.extend_from_slice(&(index.config.pq.max_iterations as u32).to_le_bytes());
buffer.push(metric_to_u8(index.config.ivf.metric));
buffer.push(if index.trained { 1 } else { 0 });
buffer.push(if index.config.use_residuals { 1 } else { 0 });
buffer.extend_from_slice(&[0u8; 17]);
buffer.extend_from_slice(&(index.ivf_centroids.len() as u32).to_le_bytes());
for &val in &index.ivf_centroids {
buffer.extend_from_slice(&val.to_le_bytes());
}
buffer.extend_from_slice(&(index.inverted_lists.len() as u32).to_le_bytes());
for (&cluster, list) in &index.inverted_lists {
buffer.extend_from_slice(&(cluster as u32).to_le_bytes());
buffer.extend_from_slice(&(list.len() as u32).to_le_bytes());
for &vector_id in list {
buffer.extend_from_slice(&vector_id.to_le_bytes());
}
}
buffer.extend_from_slice(&(index.pq_centroids.len() as u32).to_le_bytes());
for centroids in &index.pq_centroids {
buffer.extend_from_slice(&(centroids.len() as u32).to_le_bytes());
for &val in centroids {
buffer.extend_from_slice(&val.to_le_bytes());
}
}
buffer.extend_from_slice(&(index.pq_codes.len() as u32).to_le_bytes());
for (&vector_id, codes) in &index.pq_codes {
buffer.extend_from_slice(&vector_id.to_le_bytes());
buffer.extend_from_slice(&(codes.len() as u32).to_le_bytes());
buffer.extend_from_slice(codes);
}
if let Some(ref dists) = index.centroid_distances {
buffer.push(1);
buffer.extend_from_slice(&(dists.len() as u32).to_le_bytes());
for &val in dists {
buffer.extend_from_slice(&val.to_le_bytes());
}
} else {
buffer.push(0);
}
buffer
}
/// Rebuilds an [`IvfPqIndex`] from the binary form written by `serialize_ivf_pq`.
///
/// Sections are read in this order (all integers little-endian):
/// 1. Header: magic, dimensions, `n_clusters`, `n_probe`, `num_subspaces`,
///    `num_centroids`, `max_iterations` (u32 each), then metric, trained flag and
///    residual flag (u8 each), followed by 17 reserved bytes that are skipped.
/// 2. IVF centroids: u32 count, then that many `f32` values.
/// 3. Inverted lists: u32 list count, then per list a u32 cluster id, a u32
///    length and that many u64 vector ids.
/// 4. PQ centroids: u32 subspace count, then per subspace a u32 count and that
///    many `f32` values.
/// 5. PQ codes: u32 entry count, then per entry a u64 vector id, a u32 code
///    length and the raw code bytes.
/// 6. Centroid distances: a u8 presence flag; when it is 1, a u32 count and
///    that many `f32` values.
///
/// The buffer is treated as untrusted: no allocation is sized from a declared
/// count until that count has been checked against the bytes actually left.
///
/// # Errors
///
/// * `SerializeError::InvalidMagic` when the first u32 is not `IVFPQ_MAGIC`.
/// * Whatever `ensure_bytes`, the `read_*` helpers or `u8_to_metric` return for
///   a truncated or malformed buffer.
/// * `SerializeError::BufferUnderflow` (with zeroed position fields) when
///   `IvfPqIndex::from_serialized` rejects the decoded parts; the underlying
///   message is carried in `context`.
pub fn deserialize_ivf_pq(buffer: &[u8]) -> Result<IvfPqIndex, SerializeError> {
    /// Byte length of `count` elements of `elem_size` bytes each.
    ///
    /// On 32-bit targets (e.g. wasm32) a u32 count times 4 or 8 can overflow
    /// `usize`; a wrapped product would let a bogus count slip past the bounds
    /// check. On overflow this returns the largest length that can still be
    /// added to `offset` without overflowing, which no real buffer can satisfy.
    fn payload_len(count: usize, elem_size: usize, offset: usize) -> usize {
        count
            .checked_mul(elem_size)
            .unwrap_or(usize::MAX - offset)
    }

    /// Reserved header bytes that follow the three flag bytes.
    const RESERVED_HEADER_BYTES: usize = 17;

    let buf_len = buffer.len();
    ensure_bytes(buf_len, 0, IVFPQ_HEADER_SIZE, "IVF-PQ header")?;
    let mut offset = 0;

    // --- Header -----------------------------------------------------------
    let magic = read_u32_le(buffer, &mut offset, "IVF-PQ magic")?;
    if magic != IVFPQ_MAGIC {
        return Err(SerializeError::InvalidMagic {
            expected: IVFPQ_MAGIC,
            got: magic,
        });
    }
    let dimensions = read_u32_le(buffer, &mut offset, "IVF-PQ dimensions")? as usize;
    let n_clusters = read_u32_le(buffer, &mut offset, "IVF-PQ n_clusters")? as usize;
    let n_probe = read_u32_le(buffer, &mut offset, "IVF-PQ n_probe")? as usize;
    let num_subspaces = read_u32_le(buffer, &mut offset, "IVF-PQ num_subspaces")? as usize;
    let num_centroids = read_u32_le(buffer, &mut offset, "IVF-PQ num_centroids")? as usize;
    let max_iterations = read_u32_le(buffer, &mut offset, "IVF-PQ max_iterations")? as usize;
    let metric = u8_to_metric(read_u8(buffer, &mut offset, "IVF-PQ metric")?)?;
    let trained = read_u8(buffer, &mut offset, "IVF-PQ trained")? == 1;
    let use_residuals = read_u8(buffer, &mut offset, "IVF-PQ use_residuals")? == 1;
    ensure_bytes(
        buf_len,
        offset,
        RESERVED_HEADER_BYTES,
        "IVF-PQ header reserved",
    )?;
    offset += RESERVED_HEADER_BYTES;

    let config = IvfPqConfig {
        ivf: IvfConfig {
            n_clusters,
            n_probe,
            metric,
        },
        pq: PqConfig {
            num_subspaces,
            num_centroids,
            max_iterations,
        },
        use_residuals,
    };

    // --- IVF centroids ----------------------------------------------------
    let ivf_centroid_count = read_u32_le(buffer, &mut offset, "IVF-PQ centroid count")? as usize;
    ensure_bytes(
        buf_len,
        offset,
        payload_len(ivf_centroid_count, 4, offset),
        "IVF-PQ IVF centroids",
    )?;
    let mut ivf_centroids = Vec::with_capacity(ivf_centroid_count);
    for _ in 0..ivf_centroid_count {
        ivf_centroids.push(read_f32_le(buffer, &mut offset, "IVF-PQ IVF centroid")?);
    }

    // --- Inverted lists ---------------------------------------------------
    let num_lists = read_u32_le(buffer, &mut offset, "IVF-PQ inverted list count")? as usize;
    let mut inverted_lists: HashMap<usize, Vec<u64>> = HashMap::new();
    for i in 0..num_lists {
        ensure_bytes(
            buf_len,
            offset,
            8,
            &format!("IVF-PQ inverted list {i} header"),
        )?;
        let cluster = read_u32_le(
            buffer,
            &mut offset,
            &format!("IVF-PQ inverted list {i} cluster"),
        )? as usize;
        let list_length = read_u32_le(
            buffer,
            &mut offset,
            &format!("IVF-PQ inverted list {i} length"),
        )? as usize;
        ensure_bytes(
            buf_len,
            offset,
            payload_len(list_length, 8, offset),
            &format!("IVF-PQ inverted list {i} data"),
        )?;
        let mut list = Vec::with_capacity(list_length);
        for _ in 0..list_length {
            list.push(read_u64_le(
                buffer,
                &mut offset,
                "IVF-PQ inverted list vector_id",
            )?);
        }
        inverted_lists.insert(cluster, list);
    }

    // --- PQ centroids -----------------------------------------------------
    let num_pq_subspaces = read_u32_le(buffer, &mut offset, "IVF-PQ PQ subspace count")? as usize;
    // Each subspace record starts with a 4-byte count, so the bytes left bound
    // how many records can really follow. Capping the capacity keeps a corrupt
    // count from requesting a multi-gigabyte allocation up front.
    let max_subspace_records = buf_len.saturating_sub(offset) / 4;
    let mut pq_centroids = Vec::with_capacity(num_pq_subspaces.min(max_subspace_records));
    for i in 0..num_pq_subspaces {
        ensure_bytes(
            buf_len,
            offset,
            4,
            &format!("IVF-PQ PQ subspace {i} centroid count"),
        )?;
        let centroid_count = read_u32_le(
            buffer,
            &mut offset,
            &format!("IVF-PQ PQ subspace {i} centroid count"),
        )? as usize;
        ensure_bytes(
            buf_len,
            offset,
            payload_len(centroid_count, 4, offset),
            &format!("IVF-PQ PQ subspace {i} centroids"),
        )?;
        let mut centroids = Vec::with_capacity(centroid_count);
        for _ in 0..centroid_count {
            centroids.push(read_f32_le(buffer, &mut offset, "IVF-PQ PQ centroid")?);
        }
        pq_centroids.push(centroids);
    }

    // --- PQ codes ---------------------------------------------------------
    let num_pq_codes = read_u32_le(buffer, &mut offset, "IVF-PQ PQ codes count")? as usize;
    let mut pq_codes: HashMap<u64, Vec<u8>> = HashMap::new();
    for i in 0..num_pq_codes {
        ensure_bytes(buf_len, offset, 12, &format!("IVF-PQ PQ code {i} header"))?;
        let vector_id = read_u64_le(
            buffer,
            &mut offset,
            &format!("IVF-PQ PQ code {i} vector_id"),
        )?;
        let code_len =
            read_u32_le(buffer, &mut offset, &format!("IVF-PQ PQ code {i} length"))? as usize;
        ensure_bytes(
            buf_len,
            offset,
            code_len,
            &format!("IVF-PQ PQ code {i} data"),
        )?;
        // The check above guarantees this range lies inside `buffer`.
        let codes = buffer[offset..offset + code_len].to_vec();
        offset += code_len;
        pq_codes.insert(vector_id, codes);
    }

    // --- Optional centroid distance table ----------------------------------
    let has_centroid_distances =
        read_u8(buffer, &mut offset, "IVF-PQ centroid distances flag")? == 1;
    let centroid_distances = if has_centroid_distances {
        let distance_count =
            read_u32_le(buffer, &mut offset, "IVF-PQ centroid distances count")? as usize;
        ensure_bytes(
            buf_len,
            offset,
            payload_len(distance_count, 4, offset),
            "IVF-PQ centroid distances",
        )?;
        let mut dists = Vec::with_capacity(distance_count);
        for _ in 0..distance_count {
            dists.push(read_f32_le(buffer, &mut offset, "IVF-PQ centroid distance")?);
        }
        Some(dists)
    } else {
        None
    };

    // Construction failures are reported through `BufferUnderflow`; the
    // position fields carry no meaning here and are left at zero.
    IvfPqIndex::from_serialized(
        config,
        ivf_centroids,
        inverted_lists,
        pq_codes,
        pq_centroids,
        centroid_distances,
        dimensions,
        trained,
    )
    .map_err(|e| SerializeError::BufferUnderflow {
        context: format!("IVF-PQ index construction: {e}"),
        offset: 0,
        needed: 0,
        available: 0,
    })
}
/// Serializes `index` with `serialize_ivf_pq` and writes the bytes to `writer`.
///
/// Returns the number of bytes written.
///
/// # Errors
///
/// Propagates any error returned by `writer.write_all`.
pub fn write_ivf_pq<W: std::io::Write>(
    index: &IvfPqIndex,
    writer: &mut W,
) -> std::io::Result<usize> {
    let encoded = serialize_ivf_pq(index);
    let byte_count = encoded.len();
    writer.write_all(&encoded).map(|()| byte_count)
}
/// Reads `reader` to its end and decodes the bytes with `deserialize_ivf_pq`.
///
/// # Errors
///
/// A failed read is reported as `SerializeError::BufferUnderflow` whose
/// `context` is `"IO error: …"` and whose position fields are zero. Decoding
/// errors are returned unchanged.
pub fn read_ivf_pq<R: std::io::Read>(reader: &mut R) -> Result<IvfPqIndex, SerializeError> {
    let mut bytes = Vec::new();
    if let Err(io_err) = reader.read_to_end(&mut bytes) {
        return Err(SerializeError::BufferUnderflow {
            context: format!("IO error: {io_err}"),
            offset: 0,
            needed: 0,
            available: 0,
        });
    }
    deserialize_ivf_pq(&bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration shared by the tests: 4 IVF clusters with 2 probed,
    /// Euclidean metric, 4 PQ subspaces of 8 centroids, 10 iterations, and
    /// residual encoding enabled.
    fn test_config() -> IvfPqConfig {
        IvfPqConfig {
            ivf: IvfConfig {
                n_clusters: 4,
                n_probe: 2,
                metric: DistanceMetric::Euclidean,
            },
            pq: PqConfig {
                num_subspaces: 4,
                num_centroids: 8,
                max_iterations: 10,
            },
            use_residuals: true,
        }
    }

    /// Flat training set of 500 vectors x 16 dimensions where element `k`
    /// (row-major) has the value `k / 8000.0`.
    fn training_vectors() -> Vec<f32> {
        (0..500 * 16).map(|k| k as f32 / 8000.0).collect()
    }

    /// Creates a 16-dimensional index with `config`, feeds it the shared
    /// training set and trains it.
    fn trained_index_with(config: IvfPqConfig) -> IvfPqIndex {
        let mut index = IvfPqIndex::new(16, config).expect("expected value");
        index
            .add_training_vectors(&training_vectors(), 500)
            .expect("expected value");
        index.train().expect("expected value");
        index
    }

    /// Trained index using [`test_config`].
    fn trained_index() -> IvfPqIndex {
        trained_index_with(test_config())
    }

    /// Inserts ids `0..count`; vector `i` has components `(i * 16 + d) / scale`
    /// for `d` in `0..16`.
    fn insert_ramp_vectors(index: &mut IvfPqIndex, count: u64, scale: f32) {
        for id in 0..count {
            let vector: Vec<f32> = (0..16u64)
                .map(|d| (id * 16 + d) as f32 / scale)
                .collect();
            index.insert(id, &vector).expect("expected value");
        }
    }

    #[test]
    fn test_ivf_pq_new() {
        let index = IvfPqIndex::new(16, test_config()).expect("expected value");
        assert_eq!(index.dimensions, 16);
        assert_eq!(index.subspace_dims, 4);
        assert!(!index.trained);
    }

    #[test]
    fn test_ivf_pq_new_not_divisible() {
        // 15 dimensions cannot be split evenly across 4 subspaces.
        let result = IvfPqIndex::new(15, test_config());
        assert!(matches!(
            result,
            Err(IvfPqError::DimensionNotDivisible { .. })
        ));
    }

    #[test]
    fn test_ivf_pq_add_training_vectors() {
        let mut index = IvfPqIndex::new(16, test_config()).expect("expected value");
        let vectors = vec![0.0f32; 50 * 16];
        index
            .add_training_vectors(&vectors, 50)
            .expect("expected value");
        assert_eq!(index.training_count, 50);
    }

    #[test]
    fn test_ivf_pq_train() {
        let index = trained_index();
        assert!(index.trained);
        // 4 clusters x 16 dimensions, stored flat.
        assert_eq!(index.ivf_centroids.len(), 4 * 16);
        assert!(index.centroid_distances.is_some());
    }

    #[test]
    fn test_ivf_pq_train_not_enough_vectors() {
        let mut index = IvfPqIndex::new(16, test_config()).expect("expected value");
        // Only two training vectors are supplied for a 4-cluster configuration.
        let vectors = vec![0.0f32; 2 * 16];
        index
            .add_training_vectors(&vectors, 2)
            .expect("expected value");
        let result = index.train();
        assert!(matches!(
            result,
            Err(IvfPqError::NotEnoughTrainingVectors { .. })
        ));
    }

    #[test]
    fn test_ivf_pq_insert() {
        let mut index = trained_index();
        let vector = vec![0.5f32; 16];
        index.insert(0, &vector).expect("expected value");
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 1);
        assert!(index.pq_codes.contains_key(&0));
    }

    #[test]
    fn test_ivf_pq_insert_not_trained() {
        let mut index = IvfPqIndex::new(16, test_config()).expect("expected value");
        let vector = vec![0.5f32; 16];
        let result = index.insert(0, &vector);
        assert!(matches!(result, Err(IvfPqError::NotTrained)));
    }

    #[test]
    fn test_ivf_pq_delete() {
        let mut index = trained_index();
        let vector = vec![0.5f32; 16];
        index.insert(0, &vector).expect("expected value");
        assert!(index.delete(0, &vector));
        // A second delete of the same id must report that nothing was removed.
        assert!(!index.delete(0, &vector));
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 0);
    }

    #[test]
    fn test_ivf_pq_stats() {
        let mut index = trained_index();
        insert_ramp_vectors(&mut index, 10, 160.0);
        let stats = index.stats();
        assert!(stats.trained);
        assert_eq!(stats.n_clusters, 4);
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.pq_num_subspaces, 4);
        assert_eq!(stats.pq_num_centroids, 8);
        assert!(stats.memory_savings_ratio > 0.0);
    }

    #[test]
    fn test_ivf_pq_clear() {
        let mut index = trained_index();
        let vector = vec![0.5f32; 16];
        index.insert(0, &vector).expect("expected value");
        index.clear();
        assert!(!index.trained);
        assert!(index.ivf_centroids.is_empty());
        assert!(index.inverted_lists.is_empty());
        assert!(index.pq_codes.is_empty());
    }

    #[test]
    fn test_ivf_pq_config_builder() {
        let config = IvfPqConfig::new()
            .with_n_clusters(50)
            .with_n_probe(5)
            .with_metric(DistanceMetric::Euclidean)
            .with_num_subspaces(32)
            .with_num_centroids(128)
            .with_residuals(false);
        assert_eq!(config.ivf.n_clusters, 50);
        assert_eq!(config.ivf.n_probe, 5);
        assert_eq!(config.ivf.metric, DistanceMetric::Euclidean);
        assert_eq!(config.pq.num_subspaces, 32);
        assert_eq!(config.pq.num_centroids, 128);
        assert!(!config.use_residuals);
    }

    #[test]
    fn test_ivf_pq_without_residuals() {
        // Same as `test_config` except that residual encoding is turned off.
        let config = IvfPqConfig {
            use_residuals: false,
            ..test_config()
        };
        let index = trained_index_with(config);
        assert!(index.trained);
    }

    #[test]
    fn test_max_heap() {
        let mut heap = MaxHeap::new();
        heap.push(1, 0.5);
        heap.push(2, 0.3);
        heap.push(3, 0.8);
        heap.push(4, 0.1);
        assert_eq!(heap.len(), 4);
        // The largest distance must sit at the top of the heap.
        let (id, dist) = *heap.peek().expect("expected value");
        assert_eq!(id, 3);
        assert_eq!(dist, 0.8);
        let sorted = heap.into_sorted_vec();
        assert_eq!(sorted.len(), 4);
        assert!(sorted[0].1 <= sorted[1].1);
        assert!(sorted[1].1 <= sorted[2].1);
        assert!(sorted[2].1 <= sorted[3].1);
    }

    #[test]
    fn test_error_display() {
        let err1 = IvfPqError::DimensionNotDivisible {
            dimensions: 15,
            num_subspaces: 4,
        };
        assert!(err1.to_string().contains("15"));
        assert!(err1.to_string().contains("4"));
        let err2 = IvfPqError::AlreadyTrained;
        assert!(err2.to_string().contains("already"));
        let err3 = IvfPqError::NotTrained;
        assert!(err3.to_string().contains("not trained"));
    }

    #[test]
    fn test_ivf_pq_serialize_empty() {
        let index = IvfPqIndex::new(16, test_config()).expect("expected value");
        let serialized = serialize_ivf_pq(&index);
        let deserialized = deserialize_ivf_pq(&serialized).expect("expected value");
        assert_eq!(deserialized.dimensions, 16);
        assert_eq!(deserialized.config.ivf.n_clusters, 4);
        assert_eq!(deserialized.config.pq.num_subspaces, 4);
        assert!(!deserialized.trained);
    }

    #[test]
    fn test_ivf_pq_serialize_round_trip() {
        let mut index = trained_index();
        insert_ramp_vectors(&mut index, 10, 160.0);
        let serialized = serialize_ivf_pq(&index);
        let deserialized = deserialize_ivf_pq(&serialized).expect("expected value");
        assert_eq!(deserialized.dimensions, index.dimensions);
        assert_eq!(
            deserialized.config.ivf.n_clusters,
            index.config.ivf.n_clusters
        );
        assert_eq!(
            deserialized.config.pq.num_subspaces,
            index.config.pq.num_subspaces
        );
        assert_eq!(
            deserialized.config.use_residuals,
            index.config.use_residuals
        );
        assert!(deserialized.trained);
        assert_eq!(deserialized.pq_codes.len(), 10);
        assert!(deserialized.centroid_distances.is_some());
        let orig_stats = index.stats();
        let deser_stats = deserialized.stats();
        assert_eq!(orig_stats.total_vectors, deser_stats.total_vectors);
    }

    #[test]
    fn test_ivf_pq_serialize_invalid_magic() {
        let mut buffer = vec![0u8; IVFPQ_HEADER_SIZE];
        buffer[0..4].copy_from_slice(&0x00000000u32.to_le_bytes());
        let result = deserialize_ivf_pq(&buffer);
        assert!(matches!(result, Err(SerializeError::InvalidMagic { .. })));
    }

    #[test]
    fn test_ivf_pq_serialize_buffer_underflow() {
        // An empty buffer cannot even hold the header.
        let buffer: Vec<u8> = Vec::new();
        let result = deserialize_ivf_pq(&buffer);
        assert!(matches!(
            result,
            Err(SerializeError::BufferUnderflow { .. })
        ));
    }

    #[test]
    fn test_ivf_pq_serialized_size() {
        let mut index = trained_index();
        insert_ramp_vectors(&mut index, 5, 80.0);
        // The size estimate must match the bytes actually produced.
        let size = ivf_pq_serialized_size(&index);
        let serialized = serialize_ivf_pq(&index);
        assert_eq!(size, serialized.len());
    }

    #[test]
    fn test_ivf_pq_search_multi_empty_queries() {
        let index = trained_index();
        let config = crate::vector::types::VectorStoreConfig::new(16);
        let manifest = crate::vector::types::VectorManifest::new(config);
        let results = index.search_multi(&manifest, &[], 5, MultiQueryAggregation::Min, None);
        assert!(results.is_empty());
    }

    #[test]
    fn test_ivf_pq_search_multi_not_trained() {
        let index = IvfPqIndex::new(16, test_config()).expect("expected value");
        let config = crate::vector::types::VectorStoreConfig::new(16);
        let manifest = crate::vector::types::VectorManifest::new(config);
        let query = vec![0.5f32; 16];
        let results = index.search_multi(&manifest, &[&query], 5, MultiQueryAggregation::Min, None);
        assert!(results.is_empty());
    }
}