use std::cmp::min;
use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, Dot, L2};
use super::{num_centroids, utils::get_sub_vector_centroids};
/// Builds the lookup table of L2 distances between `query` and a PQ codebook.
///
/// The query is cut into consecutive pieces of `query.len() / num_sub_vectors`
/// elements. For the `i`-th piece the matching centroids are fetched with
/// `get_sub_vector_centroids`, and the distances yielded by `l2_distance_batch`
/// are appended to the result, so the table is laid out sub-vector by sub-vector.
/// Trailing query elements that do not fill a whole piece are ignored.
///
/// # Panics
///
/// Panics if `num_sub_vectors` is zero, or if it is larger than `query.len()`
/// (the sub-vector length would then be zero).
pub(super) fn build_distance_table_l2<T: L2>(
    codebook: &[T],
    num_bits: u32,
    num_sub_vectors: usize,
    query: &[T],
) -> Vec<f32> {
    let dim = query.len();
    let piece_len = dim / num_sub_vectors;

    let mut table: Vec<f32> = Vec::new();
    for (piece_idx, query_piece) in query.chunks_exact(piece_len).enumerate() {
        let centroids =
            get_sub_vector_centroids(codebook, dim, num_bits, num_sub_vectors, piece_idx);
        table.extend(l2_distance_batch(query_piece, centroids, piece_len));
    }
    table
}
/// Builds the lookup table of dot distances between `query` and a PQ codebook.
///
/// The query is cut into consecutive pieces of `query.len() / num_sub_vectors`
/// elements. For the `i`-th piece the matching centroids are fetched with
/// `get_sub_vector_centroids`, and the values yielded by `dot_distance_batch`
/// are appended to the result, so the table is laid out sub-vector by sub-vector.
/// Trailing query elements that do not fill a whole piece are ignored.
///
/// # Panics
///
/// Panics if `num_sub_vectors` is zero, or if it is larger than `query.len()`
/// (the sub-vector length would then be zero).
pub(super) fn build_distance_table_dot<T: Dot>(
    codebook: &[T],
    num_bits: u32,
    num_sub_vectors: usize,
    query: &[T],
) -> Vec<f32> {
    let dim = query.len();
    let piece_len = dim / num_sub_vectors;

    let mut table: Vec<f32> = Vec::new();
    for (piece_idx, query_piece) in query.chunks_exact(piece_len).enumerate() {
        let centroids =
            get_sub_vector_centroids(codebook, dim, num_bits, num_sub_vectors, piece_idx);
        table.extend(dot_distance_batch(query_piece, centroids, piece_len));
    }
    table
}
/// Sums distance-table entries for every PQ-encoded vector in `code`.
///
/// `code` is read sub-vector by sub-vector: the first `num_vectors` bytes are
/// the codes of sub-vector 0 for every vector, the next `num_vectors` bytes are
/// the codes of sub-vector 1, and so on, where
/// `num_vectors = code.len() / num_sub_vectors`. For each sub-vector the code
/// byte indexes into the slice of `distance_table` that starts at
/// `sub_vec_idx * num_centroids(num_bits)`, and the looked-up value is added to
/// that vector's running total.
///
/// Returns one distance per vector. When `code` holds fewer than
/// `num_sub_vectors` bytes (including the empty slice) there is no complete
/// vector to score and an empty `Vec` is returned.
///
/// NOTE(review): `code.len()` is presumably always a multiple of
/// `num_sub_vectors`; confirm against callers.
///
/// # Panics
///
/// Panics if `num_sub_vectors` is zero, or if a lookup falls outside
/// `distance_table`.
#[inline]
pub(super) fn compute_l2_distance(
    distance_table: &[f32],
    num_bits: u32,
    num_sub_vectors: usize,
    code: &[u8],
) -> Vec<f32> {
    let num_vectors = code.len() / num_sub_vectors;
    if num_vectors == 0 {
        // `chunks_exact` panics on a chunk size of zero, so bail out early.
        return Vec::new();
    }

    let mut distances = vec![0.0_f32; num_vectors];
    let num_centroids = num_centroids(num_bits);
    for (sub_vec_idx, vec_indices) in code.chunks_exact(num_vectors).enumerate() {
        // Everything from this sub-vector's first centroid onwards; the code
        // byte is the offset within it.
        let dist_table = &distance_table[sub_vec_idx * num_centroids..];
        for (&centroid_idx, sum) in vec_indices.iter().zip(distances.iter_mut()) {
            *sum += dist_table[usize::from(centroid_idx)];
        }
    }
    distances
}
/// Sums distance-table entries for PQ codes stored vector by vector.
///
/// `code` holds `num_sub_vectors` consecutive bytes per vector. Vectors are
/// processed in batches of `V`; inside a batch the sub-vectors are walked in
/// groups of `C`, and each group's partial sum is added to the vector's running
/// total. Vectors left over after the last full batch are summed one at a time.
///
/// Returns one distance per vector, in input order.
///
/// # Panics
///
/// Panics if `num_sub_vectors * V` is zero, if `C` is zero while at least one
/// full batch exists, or if a lookup falls outside `distance_table`.
// Within the visible file this is only called from the unit test, hence the
// `dead_code` allowance for non-test builds.
#[allow(dead_code)]
fn compute_l2_distance_without_transposing<const C: usize, const V: usize>(
    distance_table: &[f32],
    num_bits: u32,
    num_sub_vectors: usize,
    code: &[u8],
) -> Vec<f32> {
    let centroids_per_sub_vector = num_centroids(num_bits);
    // Table entry for code byte `centroid` of sub-vector `sub_idx`.
    let lookup = |sub_idx: usize, centroid: u8| {
        distance_table[sub_idx * centroids_per_sub_vector + centroid as usize]
    };

    let batches = code.chunks_exact(num_sub_vectors * V);
    let leftover = batches.remainder();
    let mut result: Vec<f32> = Vec::new();

    for batch in batches {
        let mut totals = [0.0_f32; V];
        for group_start in (0..num_sub_vectors).step_by(C) {
            // The last group may be shorter than `C`.
            let group_len = min(C, num_sub_vectors - group_start);
            let per_vector_codes = batch.chunks_exact(num_sub_vectors);
            for (total, vector_codes) in totals.iter_mut().zip(per_vector_codes) {
                let partial = vector_codes[group_start..group_start + group_len]
                    .iter()
                    .enumerate()
                    .map(|(offset, &centroid)| lookup(group_start + offset, centroid))
                    .sum::<f32>();
                *total += partial;
            }
        }
        result.extend(totals);
    }

    for vector_codes in leftover.chunks(num_sub_vectors) {
        let total = vector_codes
            .iter()
            .enumerate()
            .map(|(sub_idx, &centroid)| lookup(sub_idx, centroid))
            .sum::<f32>();
        result.push(total);
    }
    result
}
#[cfg(test)]
mod tests {
    use crate::vector::pq::storage::transpose;

    use super::*;
    use arrow_array::UInt8Array;

    /// `compute_l2_distance` on transposed codes must return exactly the same
    /// distances as `compute_l2_distance_without_transposing` on the original
    /// codes.
    #[test]
    fn test_compute_on_transposed_codes() {
        let num_vectors = 100;
        let num_sub_vectors = 4;
        let num_bits = 8;
        let dimension = 16;

        // Synthetic inputs: ascending values for both the codebook and the query.
        let codebook_len = num_sub_vectors * num_vectors * dimension;
        let codebook: Vec<f32> = (0..codebook_len).map(|v| v as f32).collect();
        let query: Vec<f32> = (0..dimension).map(|v| v as f32).collect();
        let table = build_distance_table_l2(&codebook, num_bits, num_sub_vectors, &query);

        // `as u8` wraps, so the code bytes cycle through 0..=255.
        let codes =
            UInt8Array::from_iter_values((0..num_vectors * num_sub_vectors).map(|v| v as u8));
        let transposed = transpose(&codes, num_vectors, num_sub_vectors);

        let actual = compute_l2_distance(&table, num_bits, num_sub_vectors, transposed.values());
        let reference = compute_l2_distance_without_transposing::<4, 1>(
            &table,
            num_bits,
            num_sub_vectors,
            codes.values(),
        );
        assert_eq!(actual, reference);
    }
}