use faer::{MatRef, RowRef};
use rayon::prelude::*;
use std::collections::BinaryHeap;
use std::ops::AddAssign;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use thousands::*;
use crate::prelude::*;
use crate::quantised::quantisers::*;
use crate::utils::k_means_utils::*;
use crate::utils::*;
/// Inverted-file (IVF) index whose per-cluster residuals are compressed with
/// an optimised product quantiser (OPQ).
///
/// After `build` has run, all per-vector data is stored in cluster order;
/// `original_ids` maps each internal position back to the caller's row index.
pub struct IvfOpqIndex<T> {
/// Quantiser codes, `m` bytes per vector, laid out contiguously in cluster order.
quantised_codes: Vec<u8>,
/// Dimensionality of the indexed vectors.
dim: usize,
/// Number of indexed vectors.
n: usize,
/// Distance metric the index was built for.
metric: Dist,
/// Coarse centroids, flattened (`nlist * dim` values).
centroids: Vec<T>,
/// Centroid norms exposed through `CentroidDistance`; `build` leaves this empty.
centroids_norm: Vec<T>,
/// Cluster-ordered vector indices produced by `build_csr_layout`; cleared once the
/// codes have been physically reordered.
all_indices: Vec<usize>,
/// Cluster boundaries: cluster `c` occupies positions `offsets[c]..offsets[c + 1]`.
offsets: Vec<usize>,
/// Quantiser trained on the residuals (vector minus assigned centroid).
codebook: OptimisedProductQuantiser<T>,
/// Number of coarse clusters (Voronoi cells).
nlist: usize,
/// Maps an internal (cluster-ordered) position to the original row index.
original_ids: Vec<usize>,
}
/// Exposes the coarse-quantiser state of the index to [`CentroidDistance`].
impl<T> CentroidDistance<T> for IvfOpqIndex<T>
where
    T: AnnSearchFloat,
{
    /// Flattened coarse centroids.
    fn centroids(&self) -> &[T] {
        self.centroids.as_slice()
    }

    /// Dimensionality of the indexed vectors.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Distance metric the index was built with.
    fn metric(&self) -> Dist {
        self.metric
    }

    /// Number of coarse clusters.
    fn nlist(&self) -> usize {
        self.nlist
    }

    /// Stored centroid norms.
    fn centroids_norm(&self) -> &[T] {
        self.centroids_norm.as_slice()
    }
}
/// Exposes the quantiser state of the index to [`VectorDistanceAdc`].
impl<T> VectorDistanceAdc<T> for IvfOpqIndex<T>
where
    T: AnnSearchFloat + AddAssign,
{
    /// Number of sub-quantisers, as reported by the codebook.
    fn codebook_m(&self) -> usize {
        let quantiser = &self.codebook;
        quantiser.m()
    }

    /// Number of centroids per sub-quantiser, as reported by the codebook.
    fn codebook_n_centroids(&self) -> usize {
        let quantiser = &self.codebook;
        quantiser.n_centroids()
    }

    /// Dimensionality of one sub-vector, as reported by the codebook.
    fn codebook_subvec_dim(&self) -> usize {
        let quantiser = &self.codebook;
        quantiser.subvec_dim()
    }

    /// Flattened coarse centroids.
    fn centroids(&self) -> &[T] {
        self.centroids.as_slice()
    }

    /// Dimensionality of the indexed vectors.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Quantiser codes for all indexed vectors.
    fn quantised_codes(&self) -> &[u8] {
        self.quantised_codes.as_slice()
    }

    /// Per-subspace codebooks held by the quantiser.
    fn codebooks(&self) -> &[Vec<T>] {
        let quantiser = &self.codebook;
        quantiser.codebooks()
    }
}
impl<T> IvfOpqIndex<T>
where
T: AnnSearchFloat + AddAssign,
{
/// Builds an IVF-OPQ index over the rows of `data`.
///
/// The build proceeds in these steps:
/// 1. flatten the matrix (and L2-normalise every row for [`Dist::Cosine`]);
/// 2. sample up to `min(256 * nlist, 250_000, n)` training vectors and train
///    `nlist` coarse centroids on them;
/// 3. train the optimised product quantiser on the training residuals
///    (vector minus assigned centroid);
/// 4. assign every vector to a centroid and encode its residual into `m` bytes;
/// 5. reorder the codes so that each cluster is contiguous in memory.
///
/// # Arguments
///
/// * `data` - Matrix whose rows are the vectors to index.
/// * `nlist` - Number of coarse clusters; defaults to `floor(sqrt(n))`, and is
///   always at least 1.
/// * `m` - Number of sub-quantisers, i.e. code bytes stored per vector.
/// * `metric` - Distance metric.
/// * `max_iters` - Iteration budget forwarded to centroid training and to the
///   quantiser trainer; defaults to 30.
/// * `opq_iter` - Forwarded unchanged to `OptimisedProductQuantiser::train`.
/// * `n_opq_centroids` - Forwarded unchanged to `OptimisedProductQuantiser::train`.
/// * `seed` - Seed for sampling and centroid training; the quantiser is trained
///   with `seed + 1000`.
/// * `verbose` - Print progress information to stdout.
///
/// # Returns
///
/// The fully built index, with codes stored in cluster order.
#[allow(clippy::too_many_arguments)]
pub fn build(
    data: MatRef<T>,
    nlist: Option<usize>,
    m: usize,
    metric: Dist,
    max_iters: Option<usize>,
    opq_iter: Option<usize>,
    n_opq_centroids: Option<usize>,
    seed: usize,
    verbose: bool,
) -> Self {
    let (mut vectors_flat, n, dim) = matrix_to_flat(data);
    let max_iters = max_iters.unwrap_or(30);
    let nlist = nlist.unwrap_or((n as f32).sqrt() as usize).max(1);

    if metric == Dist::Cosine {
        if verbose {
            println!(" Normalising vectors for cosine distance");
        }
        vectors_flat
            .par_chunks_mut(dim)
            .for_each(|chunk| normalise_vector(chunk));
    }

    // Train the coarse quantiser on a bounded sample of the data.
    let n_train = (256 * nlist).min(250_000).min(n).max(1);
    let (training_data, _) = sample_vectors(&vectors_flat, dim, n, n_train, seed);
    if verbose {
        println!(" Generating IVF-OPQ index with {} Voronoi cells.", nlist);
    }
    let mut centroids = train_centroids(
        &training_data,
        dim,
        n_train,
        nlist,
        &metric,
        max_iters,
        seed,
        verbose,
    );
    if metric == Dist::Cosine {
        if verbose {
            println!(" Normalising centroids");
        }
        centroids
            .par_chunks_mut(dim)
            .for_each(|chunk| normalise_vector(chunk));
    }

    if verbose {
        println!(" Computing residuals for OPQ training");
    }
    // Unit norms are handed to the assignment routine for every vector and centroid.
    let training_norms = vec![T::one(); n_train];
    let centroid_norms = vec![T::one(); nlist];
    let training_assignments = assign_all_parallel(
        &training_data,
        &training_norms,
        dim,
        n_train,
        &centroids,
        &centroid_norms,
        nlist,
        &metric,
    );

    // The quantiser is trained on residuals, not on the raw vectors.
    let mut training_residuals = Vec::with_capacity(training_data.len());
    for (vec_idx, &cluster_id) in training_assignments.iter().enumerate() {
        let vec_start = vec_idx * dim;
        let vec = &training_data[vec_start..vec_start + dim];
        let centroid = &centroids[cluster_id * dim..(cluster_id + 1) * dim];
        let residuals = T::subtract_simd(vec, centroid);
        training_residuals.extend_from_slice(&residuals);
    }

    if verbose {
        println!(" Training optimised product quantiser with m={}", m);
    }
    let codebook = OptimisedProductQuantiser::train(
        &training_residuals,
        dim,
        m,
        n_opq_centroids,
        opq_iter,
        max_iters,
        seed + 1000,
        verbose,
    );

    // Assign the full data set to the trained centroids.
    let data_norms = vec![T::one(); n];
    let assignments = assign_all_parallel(
        &vectors_flat,
        &data_norms,
        dim,
        n,
        &centroids,
        &centroid_norms,
        nlist,
        &metric,
    );
    if verbose {
        print_cluster_summary(&assignments, nlist);
    }

    if verbose {
        println!(" Encoding residuals");
    }
    let mut quantised_codes = vec![0u8; n * m];
    quantised_codes
        .par_chunks_mut(m)
        .zip(assignments.par_iter())
        .enumerate()
        .for_each(|(vec_idx, (chunk, &cluster_id))| {
            let vec_start = vec_idx * dim;
            let vec = &vectors_flat[vec_start..vec_start + dim];
            let centroid = &centroids[cluster_id * dim..(cluster_id + 1) * dim];
            let residual = T::subtract_simd(vec, centroid);
            let codes = codebook.encode(&residual);
            chunk.copy_from_slice(&codes);
        });
    if verbose {
        println!(" (Optimised) Quantisation complete");
    }

    // The assignments are no longer needed after encoding, so they can be
    // moved into the CSR builder instead of being cloned.
    let (all_indices, offsets) = build_csr_layout(assignments, n, nlist);

    let mut idx = Self {
        quantised_codes,
        centroids,
        all_indices,
        offsets,
        codebook,
        dim,
        n,
        nlist,
        metric,
        centroids_norm: Vec::new(),
        original_ids: Vec::new(),
    };
    // Physically reorder the codes by cluster and remember how to map the
    // new positions back to the caller's row indices.
    let new_to_old = idx.optimise_memory_layout();
    idx.original_ids = new_to_old;
    idx
}
/// Searches the index for the approximate `k` nearest neighbours of `query_vec`.
///
/// The query is L2-normalised first when the index uses [`Dist::Cosine`]. The
/// `nprobe` closest clusters are scanned and candidates are ranked by the
/// asymmetric (lookup-table) distance to their quantised residuals.
///
/// # Arguments
///
/// * `query_vec` - Query vector.
/// * `k` - Number of neighbours requested; capped at the number of indexed vectors.
/// * `nprobe` - Number of clusters to scan; defaults to `floor(sqrt(nlist))`
///   (at least 1) and is capped at `nlist`.
///
/// # Returns
///
/// `(indices, distances)`, sorted by ascending distance. Indices refer to the
/// original rows passed to `build`. Fewer than `k` entries are returned when
/// the probed clusters hold fewer candidates; both vectors are empty when
/// `k == 0` or the index is empty.
#[inline]
pub fn query(&self, query_vec: &[T], k: usize, nprobe: Option<usize>) -> (Vec<usize>, Vec<T>) {
    let k = k.min(self.n);
    // With `k == 0` the bounded heap below would be "full" while empty, and
    // peeking at its top would panic. There is nothing to return anyway.
    if k == 0 {
        return (Vec::new(), Vec::new());
    }

    let mut query_vec = query_vec.to_vec();
    if self.metric == Dist::Cosine {
        normalise_vector(&mut query_vec);
    }
    let nprobe = nprobe
        .unwrap_or_else(|| ((self.nlist as f64).sqrt() as usize).max(1))
        .min(self.nlist);

    let cluster_scores: Vec<(T, usize)> = self.get_centroids_prenorm(&query_vec, nprobe);

    // Max-heap holding the best `k` candidates seen so far; its top is the
    // current worst of them and serves as the pruning bound once the heap is full.
    let mut heap: BinaryHeap<(OrderedFloat<T>, usize)> = BinaryHeap::with_capacity(k + 1);
    let mut worst_dist = T::infinity();

    for &(_, cluster_idx) in cluster_scores.iter().take(nprobe) {
        let lookup_tables = self.build_lookup_tables_residual(&query_vec, cluster_idx);
        let start = self.offsets[cluster_idx];
        let end = self.offsets[cluster_idx + 1];
        for vec_idx in start..end {
            let dist = self.compute_distance_adc(vec_idx, &lookup_tables);
            if dist >= worst_dist {
                continue;
            }
            // Evict the current worst candidate before inserting a better one.
            if heap.len() >= k {
                heap.pop();
            }
            heap.push((OrderedFloat(dist), vec_idx));
            if heap.len() >= k {
                if let Some(top) = heap.peek() {
                    worst_dist = top.0 .0;
                }
            }
        }
    }

    let mut results: Vec<_> = heap.into_iter().collect();
    results.sort_unstable_by_key(|&(dist, _)| dist);
    // Translate internal (cluster-ordered) positions back to original row ids.
    let (distances, indices) = results
        .into_iter()
        .map(|(d, i)| (d.0, self.original_ids[i]))
        .unzip();
    (indices, distances)
}
/// Runs [`Self::query`] on a matrix row.
///
/// A row with unit column stride is viewed as a slice without copying; any
/// other stride is first gathered into a temporary vector.
///
/// # Arguments
///
/// * `query_row` - Row holding the query vector.
/// * `k` - Number of neighbours requested.
/// * `nprobe` - Number of clusters to scan (see [`Self::query`]).
///
/// # Returns
///
/// `(indices, distances)` exactly as returned by [`Self::query`].
#[inline]
pub fn query_row(
    &self,
    query_row: RowRef<T>,
    k: usize,
    nprobe: Option<usize>,
) -> (Vec<usize>, Vec<T>) {
    if query_row.col_stride() != 1 {
        // Strided row: copy the elements into contiguous storage first.
        let gathered: Vec<T> = query_row.iter().cloned().collect();
        return self.query(&gathered, k, nprobe);
    }

    // SAFETY: a column stride of 1 means the `ncols()` elements of the row sit
    // next to each other in memory starting at `as_ptr()`, and the resulting
    // slice is only used while `query_row` is still alive.
    let contiguous =
        unsafe { std::slice::from_raw_parts(query_row.as_ptr(), query_row.ncols()) };
    self.query(contiguous, k, nprobe)
}
/// Computes the approximate k-nearest-neighbour lists of every indexed vector.
///
/// Each vector is reconstructed from its stored codes (cluster centroid plus
/// decoded residual) and then passed through [`Self::query`]; the work is
/// spread over the rayon thread pool.
///
/// # Arguments
///
/// * `k` - Number of neighbours per vector.
/// * `nprobe` - Number of clusters to scan per query (see [`Self::query`]).
/// * `return_dist` - Whether to return the distances as well.
/// * `verbose` - Print a progress line every 100 000 processed vectors.
///
/// # Returns
///
/// `(indices, distances)`, both addressed by original row index; `distances`
/// is `None` unless `return_dist` is set.
pub fn generate_knn(
    &self,
    k: usize,
    nprobe: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>) {
    let code_len = self.codebook.m();
    let progress = Arc::new(AtomicUsize::new(0));

    // Recover the cluster of every internal position from the CSR offsets.
    let mut cluster_of = vec![0usize; self.n];
    for cluster_idx in 0..self.nlist {
        let first = self.offsets[cluster_idx];
        let past_last = self.offsets[cluster_idx + 1];
        (first..past_last).for_each(|pos| cluster_of[pos] = cluster_idx);
    }

    let per_vector: Vec<(usize, Vec<usize>, Vec<T>)> = (0..self.n)
        .into_par_iter()
        .map(|pos| {
            let code_start = pos * code_len;
            let codes = &self.quantised_codes[code_start..code_start + code_len];
            let centroid_start = cluster_of[pos] * self.dim;
            let centroid = &self.centroids[centroid_start..centroid_start + self.dim];

            // Approximate the stored vector: centroid + decoded residual.
            let residual = self.codebook.decode(codes);
            let reconstructed: Vec<T> = T::add_simd(centroid, &residual);
            let orig_id = self.original_ids[pos];

            if verbose {
                let done = progress.fetch_add(1, Ordering::Relaxed) + 1;
                if done.is_multiple_of(100_000) {
                    println!(
                        " Processed {} / {} samples.",
                        done.separate_with_underscores(),
                        self.n.separate_with_underscores()
                    );
                }
            }

            let (neighbours, dists) = self.query(&reconstructed, k, nprobe);
            (orig_id, neighbours, dists)
        })
        .collect();

    // Scatter the results so that row `i` of the output belongs to original row `i`.
    let mut final_indices = vec![Vec::new(); self.n];
    let mut final_dists = return_dist.then(|| vec![Vec::new(); self.n]);
    for (orig_id, neighbours, dists) in per_vector {
        final_indices[orig_id] = neighbours;
        if let Some(all_dists) = final_dists.as_mut() {
            all_dists[orig_id] = dists;
        }
    }
    (final_indices, final_dists)
}
/// Estimates the memory footprint of the index in bytes.
///
/// The total is the inline size of the struct, plus the allocated capacity
/// of every vector it owns (codes, centroids, centroid norms, CSR indices,
/// CSR offsets and the position-to-original-id map), plus whatever the
/// codebook reports for itself.
pub fn memory_usage_bytes(&self) -> usize {
    std::mem::size_of_val(self)
        + self.quantised_codes.capacity() * std::mem::size_of::<u8>()
        + self.centroids.capacity() * std::mem::size_of::<T>()
        + self.centroids_norm.capacity() * std::mem::size_of::<T>()
        + self.all_indices.capacity() * std::mem::size_of::<usize>()
        + self.offsets.capacity() * std::mem::size_of::<usize>()
        // One `usize` per indexed vector; previously left out of the total.
        + self.original_ids.capacity() * std::mem::size_of::<usize>()
        + self.codebook.memory_usage_bytes()
}
/// Reorders the quantiser codes so that every cluster is contiguous in memory.
///
/// After this call, position `p` of `quantised_codes` (in units of `m` bytes)
/// holds the codes of the vector that was originally at row `new_to_old[p]`,
/// and the range `offsets[c]..offsets[c + 1]` addresses cluster `c` directly.
/// `all_indices` is emptied because the reordered layout makes it redundant.
///
/// # Returns
///
/// `new_to_old`: for each new position, the original row index.
fn optimise_memory_layout(&mut self) -> Vec<usize> {
    let m = self.codebook.m();

    // Walk the clusters in order; the concatenation of their members is the
    // new position -> original row mapping.
    let mut new_to_old: Vec<usize> = Vec::with_capacity(self.n);
    for cluster in 0..self.nlist {
        let start = self.offsets[cluster];
        let end = self.offsets[cluster + 1];
        new_to_old.extend_from_slice(&self.all_indices[start..end]);
    }

    // Copy each vector's `m` code bytes to its new position.
    let mut new_codes = Vec::with_capacity(self.quantised_codes.len());
    for &old_id in &new_to_old {
        let start = old_id * m;
        new_codes.extend_from_slice(&self.quantised_codes[start..start + m]);
    }
    self.quantised_codes = new_codes;

    // Release the index list; positions are now implicit in the layout.
    self.all_indices.clear();
    self.all_indices.shrink_to_fit();
    new_to_old
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use faer::Mat;

    /// Six 32-dimensional rows in two groups of three: the first group starts
    /// near the origin, the second is shifted by 10.
    fn create_simple_dataset() -> Mat<f32> {
        Mat::from_fn(6, 32, |row, col| {
            let (group_offset, within_group) = if row < 3 {
                (0.0_f32, row)
            } else {
                (10.0_f32, row - 3)
            };
            group_offset + within_group as f32 * 0.1 + col as f32 * 0.01
        })
    }

    /// Index shared by most tests: simple dataset, 2 cells, `m = 8`,
    /// 10 k-means iterations, 1 OPQ iteration, 4 codebook centroids, seed 42.
    fn build_simple_index(metric: Dist) -> IvfOpqIndex<f32> {
        let data = create_simple_dataset();
        IvfOpqIndex::build(
            data.as_ref(),
            Some(2),
            8,
            metric,
            Some(10),
            Some(1),
            Some(4),
            42,
            false,
        )
    }

    /// 32-dimensional ramp `offset + x * 0.01` used as a query vector.
    fn ramp_query(offset: f32) -> Vec<f32> {
        (0..32).map(|x| offset + x as f32 * 0.01).collect()
    }

    #[test]
    fn test_build_euclidean() {
        let index = build_simple_index(Dist::Euclidean);
        assert_eq!(index.dim, 32);
        assert_eq!(index.n, 6);
        assert_eq!(index.nlist, 2);
        assert_eq!(index.metric, Dist::Euclidean);
        assert_eq!(index.quantised_codes.len(), 48);
        assert_eq!(index.centroids.len(), 64);
        assert_eq!(index.offsets.len(), 3);
    }

    #[test]
    fn test_build_cosine() {
        let index = build_simple_index(Dist::Cosine);
        assert_eq!(index.metric, Dist::Cosine);
    }

    #[test]
    fn test_query_returns_k_results() {
        let index = build_simple_index(Dist::Euclidean);
        let (indices, distances) = index.query(&ramp_query(0.0), 3, None);
        assert_eq!(indices.len(), 3);
        assert_eq!(distances.len(), 3);
    }

    #[test]
    fn test_query_k_exceeds_n() {
        let index = build_simple_index(Dist::Euclidean);
        let (indices, _) = index.query(&ramp_query(0.0), 100, None);
        assert!(indices.len() <= 6);
    }

    #[test]
    fn test_query_distances_sorted() {
        let index = build_simple_index(Dist::Euclidean);
        let (_, distances) = index.query(&ramp_query(0.0), 3, Some(2));
        for pair in distances.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    #[test]
    fn test_query_cosine() {
        let index = build_simple_index(Dist::Cosine);
        let query: Vec<f32> = (0..32).map(|x| if x < 16 { 1.0 } else { 0.0 }).collect();
        let (indices, distances) = index.query(&query, 3, None);
        assert_eq!(indices.len(), 3);
        assert_eq!(distances.len(), 3);
    }

    #[test]
    fn test_query_different_nprobe() {
        let index = build_simple_index(Dist::Euclidean);
        let query = ramp_query(5.0);
        let (indices1, _) = index.query(&query, 3, Some(1));
        let (indices2, _) = index.query(&query, 3, Some(2));
        assert_eq!(indices1.len(), 3);
        assert_eq!(indices2.len(), 3);
    }

    #[test]
    fn test_query_deterministic() {
        let index = build_simple_index(Dist::Euclidean);
        let query = ramp_query(0.5);
        let (indices1, distances1) = index.query(&query, 3, Some(2));
        let (indices2, distances2) = index.query(&query, 3, Some(2));
        assert_eq!(indices1, indices2);
        assert_eq!(distances1, distances2);
    }

    #[test]
    fn test_query_row() {
        let index = build_simple_index(Dist::Euclidean);
        let query_mat = Mat::<f32>::from_fn(1, 32, |_, j| 0.5 + j as f32 * 0.01);
        let (indices, distances) = index.query_row(query_mat.row(0), 3, None);
        assert_eq!(indices.len(), 3);
        assert_eq!(distances.len(), 3);
    }

    #[test]
    fn test_build_different_m() {
        let data = Mat::from_fn(20, 32, |i, j| (i + j) as f32);
        let index = IvfOpqIndex::build(
            data.as_ref(),
            Some(2),
            8,
            Dist::Euclidean,
            Some(5),
            Some(1),
            Some(4),
            42,
            false,
        );
        assert_eq!(index.codebook.m(), 8);
        assert_eq!(index.codebook.subvec_dim(), 4);
        assert_eq!(index.quantised_codes.len(), 160);
    }

    #[test]
    fn test_build_lookup_tables() {
        let index = build_simple_index(Dist::Euclidean);
        let table = index.build_lookup_tables_residual(&ramp_query(0.0), 0);
        assert_eq!(table.len(), 32);
    }

    #[test]
    fn test_compute_distance_adc() {
        let index = build_simple_index(Dist::Euclidean);
        let table = index.build_lookup_tables_residual(&ramp_query(0.0), 0);
        let dist = index.compute_distance_adc(0, &table);
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_opq_iterations() {
        let data = Mat::from_fn(50, 32, |i, j| (i + j) as f32);
        let index = IvfOpqIndex::build(
            data.as_ref(),
            Some(5),
            8,
            Dist::Euclidean,
            Some(5),
            Some(2),
            Some(4),
            42,
            false,
        );
        assert_eq!(index.codebook.m(), 8);
    }

    #[test]
    fn test_residual_encoding() {
        let data = Mat::from_fn(50, 32, |i, j| (i + j) as f32);
        let index = IvfOpqIndex::build(
            data.as_ref(),
            Some(5),
            8,
            Dist::Euclidean,
            Some(5),
            Some(1),
            Some(4),
            42,
            false,
        );
        let query: Vec<f32> = (0..32).map(|x| x as f32).collect();
        let (indices, _) = index.query(&query, 1, Some(5));
        assert_eq!(indices[0], 0);
    }
}