use faer::{MatRef, RowRef};
use rayon::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use thousands::Separable;
use crate::prelude::*;
use crate::utils::k_means_utils::*;
use crate::utils::*;
/// Inverted-file (IVF) index over a set of fixed-dimension vectors.
///
/// Vectors are partitioned into `nlist` clusters around trained centroids.
/// After construction the stored vectors are physically reordered so that the
/// members of each cluster sit next to each other; `original_ids` translates
/// an internal position back to the row index of the input matrix.
pub struct IvfIndex<T> {
    /// All vectors back to back; vector `i` occupies `[i * dim, (i + 1) * dim)`.
    pub vectors_flat: Vec<T>,
    /// Number of components per vector.
    pub dim: usize,
    /// Number of indexed vectors.
    pub n: usize,
    /// Per-vector L2 norms in the same order as `vectors_flat`.
    /// Only populated for `Dist::Cosine`; empty otherwise.
    pub norms: Vec<T>,
    /// Distance metric the index was built with.
    metric: Dist,
    /// Flattened centroid coordinates; centroid `c` starts at `c * dim`.
    centroids: Vec<T>,
    /// Per-centroid L2 norms. Only populated for `Dist::Cosine`; empty otherwise.
    centroids_norm: Vec<T>,
    /// Cluster membership list produced while building. It is emptied once the
    /// vectors have been reordered into cluster order.
    all_indices: Vec<usize>,
    /// Cluster boundaries: cluster `c` covers positions
    /// `offsets[c]..offsets[c + 1]`.
    offsets: Vec<usize>,
    /// Number of clusters (Voronoi cells).
    nlist: usize,
    /// Maps an internal (reordered) position to the original row index.
    original_ids: Vec<usize>,
}
/// Gives the `VectorDistance` trait read access to the stored vectors, their
/// dimensionality and their (possibly empty) norms.
impl<T: AnnSearchFloat> VectorDistance<T> for IvfIndex<T> {
    /// Flattened vector storage, `dim` values per vector.
    fn vectors_flat(&self) -> &[T] {
        self.vectors_flat.as_slice()
    }

    /// Number of components per vector.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Per-vector L2 norms; empty unless the index was built with
    /// `Dist::Cosine`.
    fn norms(&self) -> &[T] {
        self.norms.as_slice()
    }
}
/// Gives the `CentroidDistance` trait read access to the trained centroids and
/// the parameters needed to rank them against a query.
impl<T: AnnSearchFloat> CentroidDistance<T> for IvfIndex<T> {
    /// Flattened centroid coordinates, `dim` values per centroid.
    fn centroids(&self) -> &[T] {
        self.centroids.as_slice()
    }

    /// Number of components per centroid.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Metric the index was built with.
    fn metric(&self) -> Dist {
        self.metric
    }

    /// Number of clusters.
    fn nlist(&self) -> usize {
        self.nlist
    }

    /// Per-centroid L2 norms; empty unless the index was built with
    /// `Dist::Cosine`.
    fn centroids_norm(&self) -> &[T] {
        self.centroids_norm.as_slice()
    }
}
impl<T> IvfIndex<T>
where
    T: AnnSearchFloat,
{
    /// Builds an IVF index from the rows of `data`.
    ///
    /// # Arguments
    ///
    /// * `data` - Matrix whose rows are the vectors to index.
    /// * `metric` - Distance metric used for clustering and querying.
    /// * `nlist` - Number of clusters. Defaults to `floor(sqrt(n))`; always at
    ///   least 1.
    /// * `max_iters` - Iteration budget handed to `train_centroids`. Defaults
    ///   to 30.
    /// * `seed` - Seed forwarded to sampling and centroid training.
    /// * `verbose` - Print progress and a cluster summary to stdout.
    ///
    /// # Returns
    ///
    /// An index whose vectors (and norms) have been reordered into cluster
    /// order. Query results are still reported with the original row indices.
    pub fn build(
        data: MatRef<T>,
        metric: Dist,
        nlist: Option<usize>,
        max_iters: Option<usize>,
        seed: usize,
        verbose: bool,
    ) -> Self {
        let (vectors_flat, n, dim) = matrix_to_flat(data);

        // Norms are only needed (and only stored) for the cosine metric.
        let norms = if metric == Dist::Cosine {
            (0..n)
                .map(|i| {
                    let start = i * dim;
                    let end = start + dim;
                    T::calculate_l2_norm(&vectors_flat[start..end])
                })
                .collect()
        } else {
            Vec::new()
        };

        let max_iters = max_iters.unwrap_or(30);
        let nlist = nlist.unwrap_or((n as f32).sqrt() as usize).max(1);

        // Train on a sample: at most 256 vectors per cluster, at most 250k in
        // total, never more than the data set itself, and at least one.
        let n_train = (256 * nlist).min(250_000).min(n).max(1);
        let (training_data, _) = sample_vectors(&vectors_flat, dim, n, n_train, seed);

        if verbose {
            println!("  Generating IVF index with {} Voronoi cells.", nlist);
        }

        let centroids = train_centroids(
            &training_data,
            dim,
            n_train,
            nlist,
            &metric,
            max_iters,
            seed,
            verbose,
        );

        let centroids_norm = if metric == Dist::Cosine {
            (0..nlist)
                .map(|i| {
                    let start = i * dim;
                    let end = start + dim;
                    centroids[start..end]
                        .iter()
                        .map(|x| *x * *x)
                        .fold(T::zero(), |a, b| a + b)
                        .sqrt()
                })
                .collect()
        } else {
            Vec::new()
        };

        // The assignment routine always receives one norm per vector; for
        // non-cosine metrics a vector of ones is supplied as a placeholder.
        let data_norms_for_assignment = if metric == Dist::Cosine {
            norms.clone()
        } else {
            vec![T::one(); n]
        };

        let assignments = assign_all_parallel(
            &vectors_flat,
            &data_norms_for_assignment,
            dim,
            n,
            &centroids,
            &centroids_norm,
            nlist,
            &metric,
        );

        if verbose {
            print_cluster_summary(&assignments, nlist);
        }

        let (all_indices, offsets) = build_csr_layout(assignments, n, nlist);

        let mut idx = Self {
            vectors_flat,
            dim,
            n,
            norms,
            metric,
            centroids,
            centroids_norm,
            all_indices,
            offsets,
            nlist,
            original_ids: Vec::new(),
        };

        // Reorder storage into cluster order and remember how to map the new
        // positions back to the caller's row indices.
        let new_to_old = idx.optimise_memory_layout();
        idx.original_ids = new_to_old;
        idx
    }

    /// Searches the index for the nearest neighbours of `query_vec`.
    ///
    /// # Arguments
    ///
    /// * `query_vec` - Query vector; expected to have `dim` components.
    /// * `k` - Number of neighbours requested; clamped to the number of
    ///   indexed vectors.
    /// * `nprobe` - Number of clusters to scan. Defaults to
    ///   `floor(sqrt(nlist))` (at least 1) and is clamped to `nlist`.
    ///
    /// # Returns
    ///
    /// `(indices, distances)` where `indices` are original row indices and
    /// `distances[i]` belongs to `indices[i]`. Entries come out in the order
    /// kept by `SortedBuffer`. Fewer than `k` entries are returned when the
    /// probed clusters hold fewer than `k` vectors, and both vectors are empty
    /// when `k` is zero or the index is empty.
    #[inline]
    pub fn query(&self, query_vec: &[T], k: usize, nprobe: Option<usize>) -> (Vec<usize>, Vec<T>) {
        let k: usize = k.min(self.n);
        // Nothing to collect: skip centroid ranking and avoid pushing
        // candidates into a zero-capacity buffer.
        if k == 0 {
            return (Vec::new(), Vec::new());
        }

        let nprobe = nprobe
            .unwrap_or_else(|| ((self.nlist as f64).sqrt() as usize).max(1))
            .min(self.nlist);

        // Only the cosine metric needs the query norm; other metrics get a
        // placeholder of one.
        let query_norm = if matches!(self.metric, Dist::Cosine) {
            query_vec
                .iter()
                .map(|&v| v * v)
                .fold(T::zero(), |a, b| a + b)
                .sqrt()
        } else {
            T::one()
        };

        let cluster_dists: Vec<(T, usize)> = self.get_centroids_dist(query_vec, query_norm, nprobe);

        let mut buffer = SortedBuffer::with_capacity(k);
        for &(_, cluster_idx) in cluster_dists.iter().take(nprobe) {
            // Storage is in cluster order, so the offsets index it directly.
            let start = self.offsets[cluster_idx];
            let end = self.offsets[cluster_idx + 1];
            for vec_idx in start..end {
                let dist = match self.metric {
                    Dist::Euclidean => self.euclidean_distance_to_query(vec_idx, query_vec),
                    Dist::Cosine => self.cosine_distance_to_query(vec_idx, query_vec, query_norm),
                };
                buffer.insert((OrderedFloat(dist), vec_idx), k);
            }
        }

        // Translate internal positions back to original row indices.
        let (distances, indices) = buffer
            .data()
            .iter()
            .map(|(d, i)| (d.0, self.original_ids[*i]))
            .unzip();
        (indices, distances)
    }

    /// Same as [`IvfIndex::query`], but takes the query as a `faer` row view.
    ///
    /// A row with unit column stride is borrowed in place; any other stride is
    /// first copied into a temporary contiguous buffer.
    #[inline]
    pub fn query_row(
        &self,
        query_row: RowRef<T>,
        k: usize,
        nprobe: Option<usize>,
    ) -> (Vec<usize>, Vec<T>) {
        if query_row.col_stride() == 1 {
            // SAFETY: a column stride of 1 means the `ncols()` elements of the
            // row lie contiguously starting at `as_ptr()`. The slice is only
            // used within this call, while `query_row` (and therefore the
            // borrow of the underlying matrix) is still alive.
            let slice =
                unsafe { std::slice::from_raw_parts(query_row.as_ptr(), query_row.ncols()) };
            return self.query(slice, k, nprobe);
        }

        let query_vec: Vec<T> = query_row.iter().copied().collect();
        self.query(&query_vec, k, nprobe)
    }

    /// Runs a query for every indexed vector, in parallel.
    ///
    /// Each stored vector is used as the query as-is; no self-exclusion is
    /// applied here.
    ///
    /// # Arguments
    ///
    /// * `k` - Neighbours per vector (see [`IvfIndex::query`]).
    /// * `nprobe` - Clusters to scan per query (see [`IvfIndex::query`]).
    /// * `return_dist` - Whether to also return the distances.
    /// * `verbose` - Print a progress line every 100 000 processed vectors.
    ///
    /// # Returns
    ///
    /// `(indices, distances)`, both addressed by original row index: entry `i`
    /// holds the result for input row `i`. `distances` is `None` when
    /// `return_dist` is `false`.
    pub fn generate_knn(
        &self,
        k: usize,
        nprobe: Option<usize>,
        return_dist: bool,
        verbose: bool,
    ) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>) {
        let counter = Arc::new(AtomicUsize::new(0));

        // Results are tagged with the original id because the iteration runs
        // over internal (cluster-ordered) positions.
        let unordered_results: Vec<(usize, Vec<usize>, Vec<T>)> = (0..self.n)
            .into_par_iter()
            .map(|i| {
                let start = i * self.dim;
                let end = start + self.dim;
                let vec = &self.vectors_flat[start..end];
                let orig_id = self.original_ids[i];

                if verbose {
                    let count = counter.fetch_add(1, Ordering::Relaxed) + 1;
                    if count.is_multiple_of(100_000) {
                        println!(
                            "  Processed {} / {} samples.",
                            count.separate_with_underscores(),
                            self.n.separate_with_underscores()
                        );
                    }
                }

                let (indices, dists) = self.query(vec, k, nprobe);
                (orig_id, indices, dists)
            })
            .collect();

        let mut final_indices = vec![Vec::new(); self.n];
        let mut final_dists = if return_dist {
            Some(vec![Vec::new(); self.n])
        } else {
            None
        };

        // Scatter the results back into original row order.
        for (orig_id, indices, dists) in unordered_results {
            final_indices[orig_id] = indices;
            if let Some(fd) = final_dists.as_mut() {
                fd[orig_id] = dists;
            }
        }

        (final_indices, final_dists)
    }

    /// Reorders `vectors_flat` (and `norms`, when present) so that the members
    /// of each cluster are stored contiguously, in cluster order.
    ///
    /// After this call `offsets` can be used to index the vector storage
    /// directly, and `all_indices` is emptied because it is no longer needed.
    ///
    /// # Returns
    ///
    /// The permutation `new_to_old`: element `p` is the original row index of
    /// the vector now stored at position `p`.
    fn optimise_memory_layout(&mut self) -> Vec<usize> {
        // Concatenating the per-cluster membership lists yields the new order.
        let mut new_to_old: Vec<usize> = Vec::with_capacity(self.n);
        for cluster in 0..self.nlist {
            let start = self.offsets[cluster];
            let end = self.offsets[cluster + 1];
            new_to_old.extend_from_slice(&self.all_indices[start..end]);
        }

        let has_norms = !self.norms.is_empty();
        let mut new_vectors_flat = Vec::with_capacity(self.vectors_flat.len());
        let mut new_norms = if has_norms {
            Vec::with_capacity(self.n)
        } else {
            Vec::new()
        };

        for &old_id in &new_to_old {
            let start = old_id * self.dim;
            new_vectors_flat.extend_from_slice(&self.vectors_flat[start..start + self.dim]);
            if has_norms {
                new_norms.push(self.norms[old_id]);
            }
        }

        self.vectors_flat = new_vectors_flat;
        self.norms = new_norms;

        // Release the membership list; `offsets` plus the new order replace it.
        self.all_indices.clear();
        self.all_indices.shrink_to_fit();

        new_to_old
    }

    /// Estimates the memory held by the index, in bytes.
    ///
    /// The figure is the size of the struct itself plus the allocated capacity
    /// of every vector it owns, including the `original_ids` mapping.
    pub fn memory_usage_bytes(&self) -> usize {
        let float_size = std::mem::size_of::<T>();
        let index_size = std::mem::size_of::<usize>();

        std::mem::size_of_val(self)
            + self.vectors_flat.capacity() * float_size
            + self.norms.capacity() * float_size
            + self.centroids.capacity() * float_size
            + self.centroids_norm.capacity() * float_size
            + self.all_indices.capacity() * index_size
            + self.offsets.capacity() * index_size
            + self.original_ids.capacity() * index_size
    }
}
/// Hooks the index into the `KnnValidation` trait.
impl<T: AnnSearchFloat> KnnValidation<T> for IvfIndex<T> {
    /// Runs a regular query with the default number of probed clusters.
    fn query_for_validation(&self, query_vec: &[T], k: usize) -> (Vec<usize>, Vec<T>) {
        let default_nprobe = None;
        self.query(query_vec, k, default_nprobe)
    }

    /// Number of indexed vectors.
    fn n(&self) -> usize {
        self.n
    }

    /// Metric the index was built with.
    fn metric(&self) -> Dist {
        self.metric
    }

    /// Mapping from internal position to original row index.
    fn original_ids(&self) -> &[usize] {
        self.original_ids.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use faer::Mat;

    /// Five 3-dimensional points: the three unit axes plus `x + y` and `x + z`.
    fn create_simple_matrix() -> Mat<f32> {
        const ROWS: [[f32; 3]; 5] = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 0.0],
            [1.0, 0.0, 1.0],
        ];
        Mat::from_fn(5, 3, |row, col| ROWS[row][col])
    }

    /// Builds a quiet index over the simple matrix with default iterations.
    fn build_simple(metric: Dist, nlist: usize, seed: usize) -> IvfIndex<f32> {
        let points = create_simple_matrix();
        IvfIndex::build(points.as_ref(), metric, Some(nlist), None, seed, false)
    }

    #[test]
    fn test_ivf_index_creation() {
        // Construction alone must not panic.
        let _ = build_simple(Dist::Euclidean, 2, 42);
    }

    #[test]
    fn test_ivf_query_finds_self() {
        let ivf = build_simple(Dist::Euclidean, 2, 42);
        let (ids, dists) = ivf.query(&[1.0, 0.0, 0.0], 1, None);

        assert_eq!(ids.len(), 1);
        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_ivf_query_euclidean() {
        let ivf = build_simple(Dist::Euclidean, 2, 42);
        let (ids, dists) = ivf.query(&[1.0, 0.0, 0.0], 3, None);

        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
        // Distances must come back in non-decreasing order.
        for pair in dists.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    #[test]
    fn test_ivf_query_cosine() {
        let ivf = build_simple(Dist::Cosine, 2, 42);
        let (ids, dists) = ivf.query(&[1.0, 0.0, 0.0], 3, None);

        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_ivf_query_k_larger_than_dataset() {
        let ivf = build_simple(Dist::Euclidean, 2, 42);
        let (ids, _) = ivf.query(&[1.0, 0.0, 0.0], 10, None);

        assert!(ids.len() <= 5);
    }

    #[test]
    fn test_ivf_query_nprobe() {
        let ivf = build_simple(Dist::Euclidean, 3, 42);
        let target = [1.0, 0.0, 0.0];

        let (one_probe, _) = ivf.query(&target, 3, Some(1));
        let (two_probes, _) = ivf.query(&target, 3, Some(2));

        assert!(!one_probe.is_empty());
        assert!(!two_probes.is_empty());
    }

    #[test]
    fn test_ivf_reproducibility() {
        let first = build_simple(Dist::Euclidean, 2, 42);
        let second = build_simple(Dist::Euclidean, 2, 42);
        let target = [0.5, 0.5, 0.0];

        let (ids_first, _) = first.query(&target, 3, None);
        let (ids_second, _) = second.query(&target, 3, None);

        assert_eq!(ids_first, ids_second);
    }

    #[test]
    fn test_ivf_different_seeds() {
        let seeded_42 = build_simple(Dist::Euclidean, 2, 42);
        let seeded_123 = build_simple(Dist::Euclidean, 2, 123);
        let target = [0.5, 0.5, 0.0];

        let (ids_42, _) = seeded_42.query(&target, 3, Some(2));
        let (ids_123, _) = seeded_123.query(&target, 3, Some(2));

        assert!(!ids_42.is_empty());
        assert!(!ids_123.is_empty());
    }

    #[test]
    fn test_ivf_larger_dataset() {
        let n = 100;
        let dim = 10;
        let points = Mat::from_fn(n, dim, |i, j| (i * j) as f32 / 10.0);
        let ivf = IvfIndex::build(points.as_ref(), Dist::Euclidean, Some(10), None, 42, false);

        // Row 0 of the data is the zero vector, so it must be the top hit.
        let origin: Vec<f32> = vec![0.0; dim];
        let (ids, _) = ivf.query(&origin, 5, None);

        assert_eq!(ids.len(), 5);
        assert_eq!(ids[0], 0);
    }

    #[test]
    fn test_ivf_orthogonal_vectors() {
        let identity = Mat::from_fn(3, 3, |r, c| if r == c { 1.0 } else { 0.0 });
        let ivf = IvfIndex::build(identity.as_ref(), Dist::Cosine, Some(3), None, 42, false);
        let target = vec![1.0, 0.0, 0.0];

        let (ids, dists) = ivf.query(&target, 3, None);

        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
        // Any further hit is orthogonal to the query, i.e. cosine distance 1.
        if ids.len() >= 2 {
            assert_relative_eq!(dists[1], 1.0, epsilon = 1e-5);
        }
        if ids.len() >= 3 {
            assert_relative_eq!(dists[2], 1.0, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_ivf_more_clusters() {
        let few_cells = build_simple(Dist::Euclidean, 2, 42);
        let many_cells = build_simple(Dist::Euclidean, 4, 42);
        let target = [0.9, 0.1, 0.0];

        let (ids_few, _) = few_cells.query(&target, 3, Some(2));
        let (ids_many, _) = many_cells.query(&target, 3, Some(4));

        assert_eq!(ids_few.len(), 3);
        assert_eq!(ids_many.len(), 3);
    }
}