#![warn(missing_debug_implementations)]
use std::{
io::{Seek, SeekFrom, Write},
mem::size_of,
sync::atomic::AtomicBool,
vec,
};
use crate::storage::{StorageReadProvider, StorageWriteProvider};
use diskann::{
ANNError, ANNResult,
error::IntoANNResult,
utils::{VectorRepr, read_exact_into},
};
use diskann_quantization::{
CompressInto,
product::{BasicTableView, TransposedTable, train::TrainQuantizer},
};
use diskann_utils::{
io::Metadata,
views::{MatrixView, MutMatrixView},
};
use rand::{Rng, distr::Distribution};
use rayon::prelude::*;
use tracing::info;
use crate::{
model::GeneratePivotArguments,
storage::PQStorage,
utils::{
BridgeErr, ParallelIteratorInPool, RandomProvider, RayonThreadPoolRef, Timer,
create_rnd_provider_from_seed,
},
};
/// Upper bound on the number of vectors sampled for PQ pivot training.
pub const MAX_PQ_TRAINING_SET_SIZE: f64 = 50_000f64;
/// Number of centroids per PQ chunk; 256 so each code fits in one `u8`.
pub const NUM_PQ_CENTROIDS: usize = 256;
/// Default number of k-means iterations when training PQ pivots.
pub const NUM_KMEANS_REPS_PQ: usize = 12;
impl<R> diskann_quantization::random::RngBuilder<usize> for RandomProvider<R>
where
    R: Rng,
{
    type Rng = R;

    /// Derives a deterministic per-chunk RNG by using the chunk index as
    /// the seed.
    fn build_rng(&self, chunk_index: usize) -> Self::Rng {
        let seed = chunk_index as u64;
        self.create_rnd_from_seed(seed)
    }
}
pub fn generate_pq_pivots<Storage, Random>(
parameters: GeneratePivotArguments,
train_data: &mut [f32],
pq_storage: &PQStorage,
storage_provider: &Storage,
random_provider: RandomProvider<Random>,
pool: RayonThreadPoolRef<'_>,
) -> ANNResult<()>
where
Storage: StorageWriteProvider + StorageReadProvider,
Random: Rng,
{
if pq_storage.pivot_data_exist(storage_provider) {
let (file_num_centers, file_dim) =
pq_storage.read_existing_pivot_metadata(storage_provider)?;
if file_dim == parameters.dim() && file_num_centers == parameters.num_centers() {
return Ok(());
}
}
let mut centroid: Vec<f32> = vec![0.0; parameters.dim()];
if parameters.translate_to_center() {
move_train_data_by_centroid(
train_data,
parameters.num_train(),
parameters.dim(),
&mut centroid,
);
}
let mut chunk_offsets: Vec<usize> = vec![0; parameters.num_pq_chunks() + 1];
calculate_chunk_offsets(
parameters.dim(),
parameters.num_pq_chunks(),
&mut chunk_offsets,
);
let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new(
parameters.num_centers(),
parameters.max_k_means_reps(),
);
let full_pivot_data = pool.install(|| -> Result<Vec<f32>, ANNError> {
let result = trainer
.train(
MatrixView::try_from(train_data, parameters.num_train(), parameters.dim())
.bridge_err()?,
diskann_quantization::views::ChunkOffsetsView::new(chunk_offsets.as_slice())
.bridge_err()?,
diskann_quantization::Parallelism::Rayon,
&random_provider,
&diskann_quantization::cancel::DontCancel,
)
.map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?
.flatten();
Ok(result)
})?;
pq_storage.write_pivot_data(
&full_pivot_data,
¢roid,
&chunk_offsets,
parameters.num_centers(),
parameters.dim(),
storage_provider,
)?;
Ok(())
}
/// Trains PQ pivot centroids entirely from in-memory buffers.
///
/// Outputs written on success: `centroid` receives the data mean (zeros
/// when `translate_to_center` is off), `offsets` the per-chunk dimension
/// boundaries, and `full_pivot_data` the flattened
/// `num_centers * dim` pivot matrix.
///
/// # Errors
/// Returns a PQ error when any output slice has the wrong length, when
/// `cancellation_token` is already `true`, or when training fails/cancels.
#[allow(clippy::too_many_arguments)]
pub fn generate_pq_pivots_from_membuf<T: Copy + Into<f32>>(
    parameters: &GeneratePivotArguments,
    train_data_slice: &[T],
    centroid: &mut [f32],
    offsets: &mut [usize],
    full_pivot_data: &mut [f32],
    rng: &mut (impl Rng + ?Sized),
    cancellation_token: &mut bool,
    pool: RayonThreadPoolRef<'_>,
) -> ANNResult<()> {
    // Validate every output-buffer size before doing any work.
    if full_pivot_data.len() != parameters.num_centers() * parameters.dim() {
        return Err(ANNError::log_pq_error(
            "Error: full_pivot_data size is not num_centers * dim.",
        ));
    }
    if centroid.len() != parameters.dim() {
        return Err(ANNError::log_pq_error(
            "Error: centroid size is not equal to dim.",
        ));
    }
    if offsets.len() != parameters.num_pq_chunks() + 1 {
        return Err(ANNError::log_pq_error(
            "Error: invalid offsets buffer input size.",
        ));
    }
    if *cancellation_token {
        return Err(ANNError::log_pq_error(
            "Error: Cancellation requested by caller.",
        ));
    }
    // Widen the caller's element type to f32 for training.
    let mut train_data = train_data_slice
        .iter()
        .map(|x| (*x).into())
        .collect::<Vec<f32>>();
    if parameters.translate_to_center() {
        move_train_data_by_centroid(
            &mut train_data,
            parameters.num_train(),
            parameters.dim(),
            centroid,
        );
    } else {
        // No translation requested: report a zero centroid.
        for val in centroid.iter_mut() {
            *val = 0.0;
        }
    }
    calculate_chunk_offsets(parameters.dim(), parameters.num_pq_chunks(), offsets);
    let trainer = diskann_quantization::product::train::LightPQTrainingParameters::new(
        parameters.num_centers(),
        parameters.max_k_means_reps(),
    );
    // Derive a deterministic per-chunk RNG provider seeded from the
    // caller's RNG.
    let rng_builder = create_rnd_provider_from_seed(rand::distr::StandardUniform {}.sample(rng));
    let trained = pool.install(|| -> Result<Vec<f32>, ANNError> {
        // SAFETY: `AtomicBool` has the same in-memory representation as
        // `bool`, and `cancellation_token` is a valid, exclusively borrowed
        // `bool` that outlives this scope, so viewing it atomically is sound.
        // NOTE(review): with an exclusive `&mut bool` no other *safe* code
        // can flip this flag concurrently; presumably callers derive the
        // reference from a shared location — confirm the intended usage.
        let atomic_bool: &AtomicBool = unsafe { AtomicBool::from_ptr(cancellation_token) };
        let cancelation = diskann_quantization::cancel::AtomicCancelation::new(atomic_bool);
        let result = trainer
            .train(
                MatrixView::try_from(
                    train_data.as_slice(),
                    parameters.num_train(),
                    parameters.dim(),
                )
                .bridge_err()?,
                diskann_quantization::views::ChunkOffsetsView::new(offsets).bridge_err()?,
                diskann_quantization::Parallelism::Rayon,
                &rng_builder,
                &cancelation,
            )
            .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?
            .flatten();
        Ok(result)
    })?;
    full_pivot_data.copy_from_slice(&trained);
    Ok(())
}
/// Extracts one contiguous column-chunk from row-major training data.
///
/// `train_data` holds `num_train` rows of `raw_vector_dim` values each; the
/// returned vector contains, for every row, the `chunk_size` values starting
/// at column `chunk_offset` (length `num_train * chunk_size`).
///
/// # Panics
/// Panics when the requested chunk does not fit within a row, or when
/// `train_data` is shorter than `num_train * raw_vector_dim`.
#[inline]
pub fn get_chunk_from_training_data(
    train_data: &[f32],
    num_train: usize,
    raw_vector_dim: usize,
    chunk_size: usize,
    chunk_offset: usize,
) -> Vec<f32> {
    // Explicit precondition checks: previously a chunk spilling past the row
    // boundary would silently read into the next row (or panic with an
    // opaque slice-index message deep in the copy loop).
    assert!(
        chunk_offset + chunk_size <= raw_vector_dim,
        "chunk [{chunk_offset}, {}) does not fit in a row of {raw_vector_dim} dimensions",
        chunk_offset + chunk_size
    );
    assert!(
        train_data.len() >= num_train * raw_vector_dim,
        "train_data holds fewer than num_train * raw_vector_dim values"
    );
    let mut result: Vec<f32> = vec![0.0; num_train * chunk_size];
    // `chunks_exact_mut` guarantees each destination chunk is exactly
    // `chunk_size` long, so the bounds checks can be hoisted.
    result
        .chunks_exact_mut(chunk_size)
        .enumerate()
        .for_each(|(row, result_chunk)| {
            let start = row * raw_vector_dim + chunk_offset;
            result_chunk.copy_from_slice(&train_data[start..start + chunk_size]);
        });
    result
}
/// Centers `train_data` (row-major, `num_points` rows of `dimensions`)
/// in place: computes the per-dimension mean into `centroid`, then
/// subtracts it from every row.
///
/// # Panics
/// Panics when the slice lengths disagree with `num_points`/`dimensions`.
#[inline]
pub fn move_train_data_by_centroid(
    train_data: &mut [f32],
    num_points: usize,
    dimensions: usize,
    centroid: &mut [f32],
) {
    assert_eq!(train_data.len(), num_points * dimensions);
    assert_eq!(centroid.len(), dimensions);

    // Pass 1: accumulate per-dimension sums.
    centroid.fill(0.0);
    train_data.chunks_exact(dimensions).for_each(|row| {
        centroid
            .iter_mut()
            .zip(row)
            .for_each(|(sum, value)| *sum += *value);
    });

    // Turn sums into means.
    let count = num_points as f32;
    for mean in centroid.iter_mut() {
        *mean /= count;
    }

    // Pass 2: subtract the mean from every row.
    train_data.chunks_exact_mut(dimensions).for_each(|row| {
        row.iter_mut()
            .zip(centroid.iter())
            .for_each(|(value, mean)| *value -= *mean);
    });
}
/// Fills `offsets` with the dimension boundaries of `num_pq_chunks` chunks
/// covering `dimensions`: `offsets[i]..offsets[i + 1]` is chunk `i`.
/// The first `dimensions % num_pq_chunks` chunks receive one extra
/// dimension so the split is as even as possible.
#[inline]
pub fn calculate_chunk_offsets(dimensions: usize, num_pq_chunks: usize, offsets: &mut [usize]) {
    let base_width = dimensions / num_pq_chunks;
    let wider_chunks = dimensions % num_pq_chunks;
    let mut running = 0usize;
    offsets[0] = running;
    for i in 1..=num_pq_chunks {
        // Chunk i-1 is one dimension wider when it absorbs remainder.
        running += base_width;
        if i - 1 < wider_chunks {
            running += 1;
        }
        offsets[i] = running;
    }
}
/// Convenience wrapper around [`calculate_chunk_offsets`] that allocates
/// and returns the `num_pq_chunks + 1` offsets vector.
pub fn calculate_chunk_offsets_auto(dimensions: usize, num_pq_chunks: usize) -> Vec<usize> {
    let mut boundaries = vec![0usize; num_pq_chunks + 1];
    calculate_chunk_offsets(dimensions, num_pq_chunks, &mut boundaries);
    boundaries
}
/// Adds `y` element-wise to every row of `x` in place.
///
/// # Panics
/// Panics when `y` is not exactly one row wide (`x.ncols()` elements).
pub fn accum_row_inplace<T>(mut x: MutMatrixView<T>, y: &[T])
where
    T: Copy + std::ops::AddAssign,
{
    assert_eq!(x.ncols(), y.len());
    for row in x.row_iter_mut() {
        for (dst, src) in row.iter_mut().zip(y) {
            *dst += *src;
        }
    }
}
/// Compresses the full-precision dataset into PQ codes using previously
/// trained pivots, streaming the input in blocks and writing the codes to
/// the compressed-data file.
///
/// `offset` is the index of the first point to compress; with `offset > 0`
/// the existing compressed file is opened and patched in place (resume),
/// otherwise a new file is created.
///
/// # Errors
/// Fails when the pivot file is missing, or on any storage/quantization
/// error.
pub fn generate_pq_data_from_pivots<T, Storage>(
    num_centers: usize,
    num_pq_chunks: usize,
    pq_storage: &mut PQStorage,
    storage_provider: &Storage,
    offset: usize,
    pool: RayonThreadPoolRef<'_>,
) -> ANNResult<()>
where
    T: Copy + VectorRepr,
    Storage: StorageWriteProvider + StorageReadProvider,
{
    let timer = Timer::new();
    info!("Generating PQ data starting from offset {}", offset);
    let uncompressed_data_reader =
        &mut storage_provider.open_reader(pq_storage.get_data_path()?)?;
    // Resume into the existing file when starting mid-way, else create anew.
    let mut compressed_data_writer = if offset > 0 {
        storage_provider.open_writer(pq_storage.get_compressed_data_path())?
    } else {
        storage_provider.create_for_write(pq_storage.get_compressed_data_path())?
    };
    let (num_points, dim) = Metadata::read(uncompressed_data_reader)?.into_dims();
    let mut full_pivot_data: Vec<f32>;
    let centroid: Vec<f32>;
    let chunk_offsets: Vec<usize>;
    let full_dim: usize;
    if !pq_storage.pivot_data_exist(storage_provider) {
        return Err(ANNError::log_pq_error(
            "ERROR: PQ k-means pivot file not found.",
        ));
    } else {
        (_, full_dim) = pq_storage.read_existing_pivot_metadata(storage_provider)?;
        (full_pivot_data, centroid, chunk_offsets) = pq_storage.load_existing_pivot_data(
            &num_pq_chunks,
            &num_centers,
            &full_dim,
            storage_provider,
        )?;
    }
    let mut full_pivot_data_mat =
        MutMatrixView::try_from(full_pivot_data.as_mut_slice(), num_centers, full_dim)
            .bridge_err()?;
    // Pivots were trained on centroid-translated data; add the centroid back
    // into every pivot row so raw (untranslated) vectors compress directly.
    accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
    pq_storage.write_compressed_pivot_metadata::<Storage>(
        num_points,
        num_pq_chunks,
        &mut compressed_data_writer,
    )?;
    // Stream the input in blocks of up to 10k points to bound peak memory.
    const CHUNKING_BLOCK_SIZE: usize = 10_000;
    let block_size = if num_points <= CHUNKING_BLOCK_SIZE {
        num_points
    } else {
        CHUNKING_BLOCK_SIZE
    };
    let num_points_to_compress = num_points - offset;
    // Ceiling division: one extra block for any remainder.
    let num_blocks = (num_points_to_compress / block_size)
        + !num_points_to_compress.is_multiple_of(block_size) as usize;
    // Skip the two-i32 metadata header plus the already-compressed prefix.
    uncompressed_data_reader.seek(SeekFrom::Start(
        (size_of::<i32>() * 2 + offset * dim * size_of::<T>()) as u64,
    ))?;
    let table = TransposedTable::from_parts(
        full_pivot_data_mat.as_view(),
        diskann_quantization::views::ChunkOffsetsView::new(&chunk_offsets)
            .bridge_err()?
            .to_owned(),
    )
    .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
    // Scratch buffer reused across blocks for the f32-converted vectors.
    let mut buffer = vec![0.0; full_dim * block_size];
    for block_index in 0..num_blocks {
        let start_index: usize = offset + block_index * block_size;
        let end_index: usize = std::cmp::min(start_index + block_size, num_points);
        let cur_block_size: usize = end_index - start_index;
        let mut block_compressed_base: Vec<u8> = vec![0; cur_block_size * num_pq_chunks];
        let block_data: Vec<T> = read_exact_into(uncompressed_data_reader, cur_block_size * dim)?;
        // Widen each stored vector (dim values of T) into full_dim f32s.
        // NOTE(review): assumes full_dim >= dim; presumably `as_f32_into`
        // handles any trailing dimensions — confirm against VectorRepr.
        for (dst, src) in buffer
            .chunks_exact_mut(full_dim)
            .zip(block_data.chunks_exact(dim))
        {
            T::as_f32_into(src, dst).into_ann_result()?;
        }
        let block_data = &buffer[..cur_block_size * full_dim];
        // Compress in row batches so rayon tasks have useful granularity.
        const BATCH_SIZE: usize = 128;
        let mut compressed_block =
            MutMatrixView::try_from(&mut block_compressed_base, cur_block_size, num_pq_chunks)
                .bridge_err()?;
        let base_block = MatrixView::try_from(block_data, cur_block_size, full_dim).bridge_err()?;
        base_block
            .par_window_iter(BATCH_SIZE)
            .zip_eq(compressed_block.par_window_iter_mut(BATCH_SIZE))
            .try_for_each_in_pool(pool, |(src, dst)| {
                table.compress_into(src, dst).map_err(|err| {
                    ANNError::log_pq_error(diskann_quantization::error::format(&err))
                })
            })?;
        // Codes live after the two-i32 header at num_pq_chunks bytes per
        // point; an explicit seek lets resumed runs patch the right region.
        let offset = start_index * num_pq_chunks + std::mem::size_of::<i32>() * 2;
        compressed_data_writer.seek(SeekFrom::Start(offset as u64))?;
        compressed_data_writer.write_all(&block_compressed_base)?;
    }
    info!(
        "PQ data generation took {} seconds",
        timer.elapsed().as_secs_f64()
    );
    Ok(())
}
/// Compresses a single vector into PQ codes using an in-memory codebook.
///
/// The vector's length defines the full dimensionality; `pivot_data` is the
/// flattened `num_pivots x dim` codebook and `offsets` the chunk
/// boundaries. When `centroid` is `Some`, it is subtracted from the vector
/// before compression. Codes are written into `pq_out`.
///
/// # Errors
/// Returns a PQ error when views cannot be built, when the centroid length
/// does not match the vector, or when compression fails.
pub fn generate_pq_data_from_pivots_from_membuf<T: Copy + Into<f32>>(
    vector_data: &[T],
    pivot_data: &[f32],
    num_pivots: usize,
    centroid: Option<&[f32]>,
    offsets: &[usize],
    pq_out: &mut [u8],
) -> ANNResult<()> {
    let dim = vector_data.len();
    let table = BasicTableView::new(
        MatrixView::try_from(pivot_data, num_pivots, dim).bridge_err()?,
        diskann_quantization::views::ChunkOffsetsView::new(offsets).bridge_err()?,
    )
    .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
    let mut data = vector_data
        .iter()
        .map(|x| (*x).into())
        .collect::<Vec<f32>>();
    // Translate by the centroid when one is supplied. Using `if let` instead
    // of `map_or(Ok(()), ...)` makes the side-effecting branch explicit, and
    // `zip` avoids a bounds-checked index per element.
    if let Some(centroid_unwrapped) = centroid {
        if centroid_unwrapped.len() != vector_data.len() {
            return Err(ANNError::log_pq_error(
                "Error: centroids vector size does not match dimension!",
            ));
        }
        for (item, center) in data.iter_mut().zip(centroid_unwrapped) {
            *item -= *center;
        }
    }
    table
        .compress_into(data.as_slice(), pq_out)
        .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))
}
/// Compresses `num_train` vectors in parallel via the single-vector membuf
/// API, writing `num_pq_chunks` codes per vector into `pq_out`.
///
/// The centroid is only subtracted when `parameters.translate_to_center()`
/// is set.
///
/// # Errors
/// Returns a PQ error on buffer-size mismatches or on any per-vector
/// compression failure.
pub fn generate_pq_data_from_pivots_from_membuf_batch<T: Copy + Sync + Into<f32>>(
    parameters: &GeneratePivotArguments,
    vector_data: &[T],
    pivot_data: &[f32],
    centroid: &[f32],
    offsets: &[usize],
    pq_out: &mut [u8],
    pool: RayonThreadPoolRef<'_>,
) -> ANNResult<()> {
    let num_train = parameters.num_train();
    let num_pq_chunks = parameters.num_pq_chunks();
    let dim = parameters.dim();

    // Validate both buffers before launching any parallel work.
    if vector_data.len() != num_train * dim {
        return Err(ANNError::log_pq_error(
            "Error: Vector data length has the incorrect size!",
        ));
    }
    if pq_out.len() != num_train * num_pq_chunks {
        return Err(ANNError::log_pq_error(
            "Error: Invalid PQ buffer input size.",
        ));
    }

    let centroid_option: Option<&[f32]> = if parameters.translate_to_center() {
        Some(centroid)
    } else {
        None
    };

    // One parallel task per (codes, vector) pair.
    pq_out
        .par_chunks_mut(num_pq_chunks)
        .zip(vector_data.par_chunks(dim))
        .try_for_each_in_pool(pool, |(codes, vector)| {
            generate_pq_data_from_pivots_from_membuf(
                vector,
                pivot_data,
                parameters.num_centers(),
                centroid_option,
                offsets,
                codes,
            )
        })
}
#[cfg(test)]
mod pq_test {
use std::{f32, io::Write};
use crate::storage::VirtualStorageProvider;
use approx::assert_relative_eq;
use diskann::utils::IntoUsize;
use diskann_utils::test_data_root;
use rand_distr::{Distribution, Uniform};
use rstest::rstest;
use vfs::OverlayFS;
use super::*;
use crate::{
model::{
FixedChunkPQTable,
pq::{METADATA_SIZE, debug},
},
utils::{ParallelIteratorInPool, create_thread_pool_for_test, read_bin_from},
};
/// A constant dataset's centroid equals the constant, and centering drives
/// every element to zero.
#[test]
fn test_move_train_data_by_centroid() {
    let (dim, num_data) = (20, 200);
    let val: f32 = 1.0;
    let mut data = vec![val; dim * num_data];
    let mut centroid = vec![0.0; dim];
    move_train_data_by_centroid(&mut data, num_data, dim, &mut centroid);
    assert!(centroid.iter().all(|&c| c == val));
    assert!(data.iter().all(|&d| d == 0.0));
}
/// End-to-end pivot generation on an in-memory VFS, then verifies the pivot
/// file layout: a 4-entry u64 offset table followed by the pivot matrix,
/// the centroid vector, and the u32 chunk-offsets table.
#[test]
fn generate_pq_pivots_test() {
    let storage_provider = VirtualStorageProvider::new_memory();
    let pivot_file_name = "/generate_pq_pivots_test3.bin";
    let compressed_file_name = "/compressed2.bin";
    let pq_training_file_name = "/file_not_used.bin";
    let pq_storage: PQStorage = PQStorage::new(
        pivot_file_name,
        compressed_file_name,
        Some(pq_training_file_name),
    );
    // Five 8-dimensional vectors: four clustered near 1.0..2.2, one outlier
    // at 100.0.
    let mut train_data: Vec<f32> = vec![
        1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
        2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
        2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
        100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
    ];
    let pool = create_thread_pool_for_test();
    // num_train=5, dim=8, num_centers=2, num_pq_chunks=2, 5 reps, zero-mean.
    generate_pq_pivots(
        GeneratePivotArguments::new(5, 8, 2, 2, 5, true).unwrap(),
        &mut train_data,
        &pq_storage,
        &storage_provider,
        crate::utils::create_rnd_provider_from_seed_in_tests(42),
        pool.as_ref(),
    )
    .unwrap();
    let mut reader = storage_provider.open_reader(pivot_file_name).unwrap();
    // The file begins with a column of section offsets.
    let offsets = read_bin_from::<u64>(&mut reader, 0).unwrap();
    let file_offset_data = offsets.map(|x| x.into_usize());
    assert_eq!(file_offset_data[(0, 0)], METADATA_SIZE);
    assert_eq!(offsets.nrows(), 4);
    assert_eq!(offsets.ncols(), 1);
    // Section 0: the 2 x 8 pivot matrix.
    let pivots = read_bin_from::<f32>(&mut reader, file_offset_data[(0, 0)]).unwrap();
    assert_eq!(pivots.as_slice().len(), 16);
    assert_eq!(pivots.nrows(), 2);
    assert_eq!(pivots.ncols(), 8);
    // Section 1: the centroid; every dimension holds the column mean.
    let centroid = read_bin_from::<f32>(&mut reader, file_offset_data[(1, 0)]).unwrap();
    assert_eq!(
        centroid[(0, 0)],
        (1.0f32 + 2.0f32 + 2.1f32 + 2.2f32 + 100.0f32) / 5.0f32
    );
    assert_eq!(centroid.nrows(), 8);
    assert_eq!(centroid.ncols(), 1);
    // Section 2: chunk offsets [0, 4, 8] for 2 chunks over 8 dimensions.
    let chunk_offsets = read_bin_from::<u32>(&mut reader, file_offset_data[(2, 0)])
        .unwrap()
        .map(|x| x.into_usize());
    assert_eq!(chunk_offsets[(0, 0)], 0);
    assert_eq!(chunk_offsets[(1, 0)], 4);
    assert_eq!(chunk_offsets[(2, 0)], 8);
    assert_eq!(chunk_offsets.nrows(), 3);
    assert_eq!(chunk_offsets.ncols(), 1);
}
/// Membuf pivot training succeeds for both zero-mean settings and for chunk
/// counts dividing the dimension evenly (2) and unevenly (3).
#[rstest]
#[case(false, 2)]
#[case(true, 2)]
#[case(false, 3)]
#[case(true, 3)]
fn generate_pq_pivots_membuf_test(#[case] make_zero_mean: bool, #[case] num_pq_chunks: usize) {
    let num_train = 5;
    let dim = 8;
    let num_centers = 2;
    let train_data: [f32; 40] = [
        1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
        2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
        2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
        100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
    ];
    let mut full_pivot_data: Vec<f32> = vec![0.0; num_centers * dim];
    let mut centroids: Vec<f32> = vec![0.0; dim];
    let mut offsets: Vec<usize> = vec![0; num_pq_chunks + 1];
    let pool = create_thread_pool_for_test();
    let result = generate_pq_pivots_from_membuf(
        &GeneratePivotArguments::new(
            num_train,
            dim,
            num_centers,
            num_pq_chunks,
            5,
            make_zero_mean,
        )
        .unwrap(),
        &train_data,
        &mut centroids,
        &mut offsets,
        &mut full_pivot_data,
        &mut crate::utils::create_rnd_in_tests(),
        &mut (false),
        pool.as_ref(),
    );
    assert!(result.is_ok());
    // 2 centers x 8 dims of trained pivots were written.
    assert_eq!(full_pivot_data.len(), 16);
}
/// When a pivot file already exists and its metadata matches the requested
/// (num_centers, dim), `generate_pq_pivots` returns Ok without retraining.
/// Note `train_data` (50 zeros) is deliberately far too small for
/// num_train * dim — presumably the early-return path never touches it;
/// confirm the checked-in pivots file reports 256 centers / dim 128.
#[test]
fn read_pivot_metadata_existing_test() {
    const DATA_FILE: &str = "/test/test/fake.bin";
    const PQ_PIVOT_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
    const PQ_COMPRESSED_PATH: &str = "/test/test/fake.bin";
    let mut train_data = vec![0.0; 10 * 5];
    let num_train = 10;
    let dim = 128;
    let num_centers = 256;
    let num_pq_chunks = dim - 1;
    let max_k_means_reps = 10;
    let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
    let pq_storage = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, Some(DATA_FILE));
    let pool = create_thread_pool_for_test();
    let result = generate_pq_pivots(
        GeneratePivotArguments::new(
            num_train,
            dim,
            num_centers,
            num_pq_chunks,
            max_k_means_reps,
            true,
        )
        .unwrap(),
        &mut train_data,
        &pq_storage,
        &storage_provider,
        crate::utils::create_rnd_provider_from_seed_in_tests(42),
        pool.as_ref(),
    );
    assert!(result.is_ok());
}
/// Writes a tiny 5 x 8 dataset to the VFS, trains pivots, compresses it,
/// and checks the code matrix shape plus basic cluster structure
/// (near-duplicate rows share a code; the outlier row differs).
#[test]
fn generate_pq_data_from_pivots_test() {
    let storage_provider = VirtualStorageProvider::new_memory();
    let data_file = "/generate_pq_data_from_pivots_test_data.bin";
    let mut train_data: Vec<f32> = vec![
        1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
        2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
        2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
        100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
    ];
    let my_nums_unstructured: &[u8] = bytemuck::must_cast_slice(&train_data);
    // Data-file header: [num_points, dim] as two i32s, then raw f32 values.
    let meta: Vec<i32> = vec![5, 8];
    let meta_unstructured: &[u8] = bytemuck::must_cast_slice(&meta);
    {
        let mut data_file_writer = storage_provider.create_for_write(data_file).unwrap();
        data_file_writer
            .write_all(meta_unstructured)
            .expect("Failed to write sample file");
        data_file_writer
            .write_all(my_nums_unstructured)
            .expect("Failed to write sample file");
    }
    let pq_pivots_path = "/generate_pq_data_from_pivots_test_pivot.bin";
    let pq_compressed_vectors_path = "/generate_pq_data_from_pivots_test.bin";
    let mut pq_storage =
        PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, Some(data_file));
    let pool = create_thread_pool_for_test();
    generate_pq_pivots(
        GeneratePivotArguments::new(5, 8, 2, 2, 5, true).unwrap(),
        &mut train_data,
        &pq_storage,
        &storage_provider,
        crate::utils::create_rnd_provider_from_seed_in_tests(42),
        pool.as_ref(),
    )
    .unwrap();
    generate_pq_data_from_pivots::<f32, _>(
        2,
        2,
        &mut pq_storage,
        &storage_provider,
        0,
        pool.as_ref(),
    )
    .unwrap();
    let compressed = read_bin_from::<u8>(
        &mut storage_provider
            .open_reader(pq_compressed_vectors_path)
            .unwrap(),
        0,
    )
    .unwrap();
    // 5 points x 2 PQ chunks of one-byte codes.
    assert_eq!(compressed.nrows(), 5);
    assert_eq!(compressed.ncols(), 2);
    // Rows 0 and 1 lie in the same cluster; row 4 is the 100.0 outlier.
    assert_eq!(compressed[(0, 0)], compressed[(1, 0)]);
    assert_ne!(compressed[(0, 0)], compressed[(4, 0)]);
    storage_provider.delete(data_file).unwrap();
    storage_provider.delete(pq_pivots_path).unwrap();
    storage_provider.delete(pq_compressed_vectors_path).unwrap();
}
/// Trains pivots via the membuf API, compresses every row one at a time,
/// and verifies every output buffer (pre-filled with MAX sentinels) was
/// fully overwritten.
#[rstest]
#[case(false, 2)]
#[case(true, 2)]
#[case(false, 3)]
#[case(true, 3)]
fn generate_pq_data_from_pivots_membuf_test(
    #[case] make_zero_mean: bool,
    #[case] num_pq_chunks: usize,
) {
    let num_train: usize = 5;
    let dim: usize = 8;
    let num_centers: usize = 2;
    let max_k_means_reps: usize = 5;
    let train_data: Vec<f32> = vec![
        1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
        2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
        2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
        100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
    ];
    // Sentinel values let the asserts below detect untouched output slots.
    let mut centroids: Vec<f32> = vec![f32::MAX; dim];
    let mut offsets: Vec<usize> = vec![usize::MAX; num_pq_chunks + 1];
    let mut pivot_data: Vec<f32> = vec![f32::MAX; num_centers * dim];
    let pool = create_thread_pool_for_test();
    generate_pq_pivots_from_membuf(
        &GeneratePivotArguments::new(
            num_train,
            dim,
            num_centers,
            num_pq_chunks,
            max_k_means_reps,
            make_zero_mean,
        )
        .unwrap(),
        &train_data,
        &mut centroids,
        &mut offsets,
        &mut pivot_data,
        &mut crate::utils::create_rnd_in_tests(),
        &mut (false),
        pool.as_ref(),
    )
    .unwrap();
    // Compress each row individually through the single-vector API.
    let mut pq: Vec<u8> = vec![0; num_pq_chunks];
    for i in 0..num_train {
        generate_pq_data_from_pivots_from_membuf(
            &train_data[dim * i..dim * (i + 1)],
            &pivot_data,
            num_centers,
            make_zero_mean.then_some(&centroids),
            &offsets,
            &mut pq,
        )
        .unwrap();
    }
    assert!(
        !centroids.contains(&f32::MAX),
        "centroids contains max value!"
    );
    assert!(
        !offsets.contains(&usize::MAX),
        "offsets contains max value!"
    );
    assert!(
        !pivot_data.contains(&f32::MAX),
        "pivot_data contains max value!"
    );
    // Without zero-mean translation the centroid must be reported as zero.
    if !make_zero_mean {
        assert!(
            centroids.iter().all(|&x| x == 0.0),
            "centroids is not all 0"
        );
    }
}
/// Cross-checks the file-based compression pipeline against the membuf API
/// on siftsmall: codes are compared per point, and any mismatches must be
/// few and have nearly identical quantization error.
#[rstest]
#[case(true, 16)]
#[case(true, 32)]
#[case(true, 17)]
#[case(true, 13)]
fn verify_identical_results_for_membuf_api(
    #[case] make_zero_mean: bool,
    #[case] num_pq_chunks: usize,
) {
    let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
    let data_file = "/sift/siftsmall_learn.bin";
    let pq_pivots_path = "/pq_pivots_validation.bin";
    let pq_compressed_vectors_path = "/pq_validation.bin";
    let mut pq_storage: PQStorage =
        PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, Some(data_file));
    let pool = create_thread_pool_for_test();
    // Sample the full dataset (p = 1.0) as training data.
    let (mut full_data_vector, num_train, train_dim) = pq_storage
        .get_random_train_data_slice::<f32, VirtualStorageProvider<OverlayFS>>(
            1.0,
            &storage_provider,
            &mut crate::utils::create_rnd_in_tests(),
        )
        .unwrap();
    generate_pq_pivots(
        GeneratePivotArguments::new(
            num_train,
            train_dim,
            NUM_PQ_CENTROIDS,
            num_pq_chunks,
            NUM_KMEANS_REPS_PQ,
            false,
        )
        .expect("Failed to create pivot parameters"),
        &mut full_data_vector,
        &pq_storage,
        &storage_provider,
        crate::utils::create_rnd_provider_from_seed_in_tests(42),
        pool.as_ref(),
    )
    .expect("Failed to generate pivots");
    // Reference path: compress via the file-based pipeline.
    generate_pq_data_from_pivots::<f32, _>(
        NUM_PQ_CENTROIDS,
        num_pq_chunks,
        &mut pq_storage,
        &storage_provider,
        0,
        pool.as_ref(),
    )
    .expect("Failed to generate quantized data");
    let (full_pivot_data, centroid, offsets) = pq_storage
        .load_existing_pivot_data(
            &num_pq_chunks,
            &NUM_PQ_CENTROIDS,
            &train_dim,
            &storage_provider,
        )
        .unwrap();
    // Comparison path: compress each vector through the membuf API.
    let mut membuf_pq_data: Vec<u8> = vec![0; num_pq_chunks * num_train];
    membuf_pq_data
        .par_chunks_mut(num_pq_chunks)
        .enumerate()
        .for_each_in_pool(pool.as_ref(), |(i, membuf_slice)| {
            generate_pq_data_from_pivots_from_membuf(
                &full_data_vector[train_dim * i..train_dim * (i + 1)],
                &full_pivot_data,
                NUM_PQ_CENTROIDS,
                make_zero_mean.then_some(&centroid),
                &offsets,
                membuf_slice,
            )
            .unwrap();
        });
    let original_pq_data = read_bin_from::<u8>(
        &mut storage_provider
            .open_reader(pq_compressed_vectors_path)
            .unwrap(),
        0,
    )
    .unwrap();
    let membuf_view =
        MatrixView::try_from(membuf_pq_data.as_slice(), num_train, num_pq_chunks).unwrap();
    let original_view =
        MatrixView::try_from(original_pq_data.as_slice(), num_train, num_pq_chunks).unwrap();
    let mut offsets = vec![0; num_pq_chunks + 1];
    calculate_chunk_offsets(train_dim, num_pq_chunks, &mut offsets);
    let offset_view = diskann_quantization::views::ChunkOffsetsView::new(&offsets).unwrap();
    let full_data =
        MatrixView::try_from(full_data_vector.as_slice(), num_train, train_dim).unwrap();
    let pivot_view =
        MatrixView::try_from(full_pivot_data.as_slice(), NUM_PQ_CENTROIDS, train_dim).unwrap();
    // Tolerances for acceptable disagreement between the two pipelines.
    let max_relative_error = 2.05e-5;
    let max_mismatches = 6;
    let mismatch_records = debug::compare_pq(
        full_data,
        offset_view,
        pivot_view,
        &centroid,
        membuf_view,
        original_view,
    );
    let mut max_relative_error_seen: f32 = 0.0;
    mismatch_records.iter().enumerate().for_each(|(i, r)| {
        println!("Mismatch {} of {}\n", i + 1, mismatch_records.len());
        println!("{}", r);
        // Relative difference between the two codes' quantization errors.
        let relative_error = (r.squared_l2_a - r.squared_l2_b).abs() / (r.squared_l2_b);
        println!("relative error = {relative_error}");
        max_relative_error_seen = max_relative_error_seen.max(relative_error)
    });
    assert!(
        max_relative_error_seen <= max_relative_error,
        "observed max relative error {max_relative_error_seen} exceeds the configured \
         upper bound of {max_relative_error}"
    );
    assert!(
        mismatch_records.len() <= max_mismatches,
        "observed {} mismatches when a maximum of {} was allowed",
        mismatch_records.len(),
        max_mismatches
    );
}
/// Random-data generation strategies used to stress the PQ membuf API.
enum RandGenStrategy {
    /// Uniform values in [-100, 100).
    HundredMaxMin,
    /// Uniform values in [0, 100).
    ZeroToHundred,
    /// Random vectors normalized to unit L2 norm.
    UnitSphere,
    /// Ratio of two uniform draws; denominators near `f32::EPSILON` can
    /// produce very large values.
    RandDivByRand,
}
/// Smoke test: the membuf training + compression APIs must run without
/// error across several data distributions, dimensions, and chunk counts.
#[rstest]
#[case(16, 8)]
#[case(8, 3)]
#[case(7, 5)]
#[case(10, 5)]
#[case(20, 2)]
#[case(3, 3)]
fn check_pq_api_for_membuf_runs_with_rand_f32_data(
    #[values(
        RandGenStrategy::HundredMaxMin,
        RandGenStrategy::ZeroToHundred,
        RandGenStrategy::UnitSphere,
        RandGenStrategy::RandDivByRand
    )]
    rand_strategy: RandGenStrategy,
    #[values(false, true)] make_zero_mean: bool,
    #[values(256)] npts: usize,
    #[case] dim: usize,
    #[case] num_pq_chunks: usize,
) {
    let mut rng = crate::utils::create_rnd_provider_from_seed(42).create_rnd();
    // Synthesize npts * dim values according to the chosen strategy.
    let full_data_vector: Vec<f32> = match rand_strategy {
        RandGenStrategy::HundredMaxMin => (0..npts * dim)
            .map(|_| rng.random_range(-100.0..100.0))
            .collect(),
        RandGenStrategy::ZeroToHundred => (0..npts * dim)
            .map(|_| rng.random_range(0.0..100.0))
            .collect(),
        RandGenStrategy::UnitSphere => {
            let mut data: Vec<f32> = (0..npts * dim)
                .map(|_| rng.random_range(-100.0..100.0))
                .collect();
            // Normalize each dim-length vector to unit L2 norm.
            let norms: Vec<f32> = data
                .chunks(dim)
                .map(|v| v.iter().map(|x| x * x).sum::<f32>().sqrt())
                .collect();
            for (slice, norm) in data.chunks_mut(dim).zip(norms) {
                for iter in slice.iter_mut() {
                    *iter /= norm;
                }
            }
            data
        }
        RandGenStrategy::RandDivByRand => (0..npts * dim)
            .map(|_| rng.random_range(0.0..100.0) / rng.random_range(f32::EPSILON..100.0))
            .collect(),
    };
    let mut full_pivot_data: Vec<f32> = vec![0.0; NUM_PQ_CENTROIDS * dim];
    let mut centroids: Vec<f32> = vec![0.0; dim];
    let mut offsets: Vec<usize> = vec![0; num_pq_chunks + 1];
    let pool = create_thread_pool_for_test();
    let result = generate_pq_pivots_from_membuf(
        &GeneratePivotArguments::new(
            npts,
            dim,
            NUM_PQ_CENTROIDS,
            num_pq_chunks,
            crate::model::pq::pq_construction::NUM_KMEANS_REPS_PQ,
            make_zero_mean,
        )
        .unwrap(),
        &full_data_vector,
        &mut centroids,
        &mut offsets,
        &mut full_pivot_data,
        &mut crate::utils::create_rnd_in_tests(),
        &mut (false),
        pool.as_ref(),
    );
    assert!(result.is_ok());
    // Compress every point; only successful completion is checked here.
    let mut membuf_pq_data: Vec<u8> = vec![0; num_pq_chunks];
    for i in 0..npts {
        let result = generate_pq_data_from_pivots_from_membuf(
            &full_data_vector[(dim * i)..(dim * (i + 1))],
            &full_pivot_data,
            NUM_PQ_CENTROIDS,
            make_zero_mean.then_some(&centroids),
            &offsets,
            &mut membuf_pq_data,
        );
        assert!(result.is_ok());
    }
}
/// Regression test: compressing siftsmall with the checked-in codebook must
/// reproduce the checked-in ground-truth compressed file byte-for-byte.
#[test]
fn pq_end_to_end_validation_with_codebook_test() {
    let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
    let data_file = "/sift/siftsmall_learn.bin";
    let pq_pivots_path = "/sift/siftsmall_learn_pq_pivots.bin";
    let ground_truth_path = "/sift/siftsmall_learn_pq_compressed.bin";
    let pq_compressed_vectors_path = "/validation.bin";
    let mut pq_storage =
        PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, Some(data_file));
    let pool = create_thread_pool_for_test();
    generate_pq_data_from_pivots::<f32, _>(
        NUM_PQ_CENTROIDS,
        1,
        &mut pq_storage,
        &storage_provider,
        0,
        pool.as_ref(),
    )
    .expect("Failed to generate quantized data");
    let data = read_bin_from::<u8>(
        &mut storage_provider
            .open_reader(pq_compressed_vectors_path)
            .unwrap(),
        0,
    )
    .unwrap();
    let gt_data = read_bin_from::<u8>(
        &mut storage_provider.open_reader(ground_truth_path).unwrap(),
        0,
    )
    .unwrap();
    assert_eq!(data, gt_data);
}
/// Chunk 0 of size 3 selects columns 0..3 of each 7-dimensional row.
#[test]
fn get_chunk_from_training_data_chunk0() {
    let train_data = vec![
        0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6,
    ];
    let chunk = get_chunk_from_training_data(&train_data, 2, 7, 3, 0);
    assert_eq!(chunk, vec![0.0, 0.1, 0.2, 1.0, 1.1, 1.2]);
}
/// Chunk 1 of size 3 selects columns 3..6 of each 7-dimensional row.
#[test]
fn get_chunk_from_training_data_chunk1() {
    let train_data = vec![
        0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6,
    ];
    let chunk_size = 3;
    let chunk_id = 1;
    let chunk = get_chunk_from_training_data(&train_data, 2, 7, chunk_size, chunk_size * chunk_id);
    assert_eq!(chunk, vec![0.3, 0.4, 0.5, 1.3, 1.4, 1.5]);
}
/// End-to-end reranking quality check: trains pivots on a sample, batch
/// compresses the full dataset, verifies PQ self-distance consistency, then
/// measures recall of PQ nearest neighbors against exact ground truth.
#[rstest]
#[case("l2", 31)]
#[case("l2", 32)]
#[case("inner_product", 31)]
fn rerankingtest_with_membuf_pq_functions(
    #[case] distance_function: String,
    #[case] num_pq_chunks: usize,
) {
    let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
    let data_file = "/sift/siftsmall_learn.bin";
    let pq_pivots_path = "/pq_pivots_validation.bin";
    let pq_compressed_vectors_path = "/pq_validation.bin";
    let pq_storage: PQStorage =
        PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, Some(data_file));
    let num_runs = 1;
    let num_closest_pq_vectors = 100;
    let num_closest_gt_vectors = 10;
    let p_val = 0.1;
    // Train pivots on a 10% sample of the dataset.
    let (train_data_vector, train_size, train_dim) = pq_storage
        .get_random_train_data_slice::<f32, VirtualStorageProvider<OverlayFS>>(
            p_val,
            &storage_provider,
            &mut crate::utils::create_rnd_in_tests(),
        )
        .unwrap();
    let mut full_pivot_data: Vec<f32> = vec![0.0; NUM_PQ_CENTROIDS * train_dim];
    let mut centroid: Vec<f32> = vec![0.0; train_dim];
    let mut offsets: Vec<usize> = vec![0; num_pq_chunks + 1];
    let pivot_args = GeneratePivotArguments::new(
        train_size,
        train_dim,
        NUM_PQ_CENTROIDS,
        num_pq_chunks,
        crate::model::pq::pq_construction::NUM_KMEANS_REPS_PQ,
        false,
    )
    .unwrap();
    let pool = create_thread_pool_for_test();
    generate_pq_pivots_from_membuf(
        &pivot_args,
        &train_data_vector,
        &mut centroid,
        &mut offsets,
        &mut full_pivot_data,
        &mut crate::utils::create_rnd_in_tests(),
        &mut (false),
        pool.as_ref(),
    )
    .unwrap();
    // Reload the whole dataset (p = 1.0) and rebuild the arguments with its
    // actual size.
    let (mut full_data_vector, train_size, train_dim) = pq_storage
        .get_random_train_data_slice::<f32, VirtualStorageProvider<OverlayFS>>(
            1.0,
            &storage_provider,
            &mut crate::utils::create_rnd_in_tests(),
        )
        .unwrap();
    let pivot_args = GeneratePivotArguments::new(
        train_size,
        pivot_args.dim(),
        pivot_args.num_centers(),
        pivot_args.num_pq_chunks(),
        pivot_args.max_k_means_reps(),
        pivot_args.translate_to_center(),
    )
    .unwrap();
    let pool = create_thread_pool_for_test();
    // Batch compress all vectors with the trained pivots.
    let mut pq_data: Vec<u8> = vec![0; num_pq_chunks * train_size];
    generate_pq_data_from_pivots_from_membuf_batch(
        &pivot_args,
        &full_data_vector,
        &full_pivot_data,
        &centroid,
        &offsets,
        &mut pq_data,
        pool.as_ref(),
    )
    .unwrap();
    let fixed_chunk_pq_table = FixedChunkPQTable::new(
        train_dim,
        full_pivot_data.into(),
        centroid.clone().into(),
        offsets.into(),
    )
    .unwrap();
    // Consistency check: quantized-to-quantized distance must match the
    // distance computed from the inflated (decoded) representation.
    let pairs = [(0, 1), (1, 0), (10, 10), (23, 42)];
    for (a, b) in pairs {
        let left = &pq_data[a * num_pq_chunks..(a + 1) * num_pq_chunks];
        let right = &pq_data[b * num_pq_chunks..(b + 1) * num_pq_chunks];
        let self_l2 = fixed_chunk_pq_table.qq_l2_distance(left, right);
        let mut inflated = fixed_chunk_pq_table.inflate_vector(left);
        fixed_chunk_pq_table.preprocess_query(&mut inflated);
        let from_inflated = fixed_chunk_pq_table.l2_distance(&inflated, right);
        assert_relative_eq!(self_l2, from_inflated, max_relative = 1e-6);
    }
    let mut rng = crate::utils::create_rnd_in_tests();
    let int_distribution = Uniform::try_from(0..train_size).unwrap();
    let mut counter_sum = vec![0; num_runs];
    for item in counter_sum.iter_mut().take(num_runs) {
        // Pick a random query point and rank all other points by PQ distance.
        let query_index = int_distribution.sample(&mut rng);
        let mut query_vec =
            full_data_vector[train_dim * query_index..train_dim * (query_index + 1)].to_vec();
        let query = query_vec.as_mut_slice();
        let mut distance_map: Vec<(f32, usize)> = Vec::new();
        for i in 0..train_size {
            if i == query_index {
                continue;
            }
            let compressed_data = pq_data[i * num_pq_chunks..(i + 1) * num_pq_chunks].to_vec();
            match distance_function.as_str() {
                "l2" => {
                    let distance = fixed_chunk_pq_table.l2_distance(query, &compressed_data);
                    distance_map.push((distance, i));
                }
                "inner_product" => {
                    let distance = fixed_chunk_pq_table.inner_product(query, &compressed_data);
                    distance_map.push((distance, i));
                }
                _ => panic!("Invalid distance function"),
            }
        }
        distance_map.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let closest_pq_vectors: Vec<usize> = distance_map
            .into_iter()
            .take(num_closest_pq_vectors)
            .map(|(_, value)| value)
            .collect();
        // Exact ground-truth ranking over the full-precision vectors.
        let mut gt_map: Vec<(f32, usize)> = Vec::new();
        for i in 0..train_size {
            if i == query_index {
                continue;
            }
            let data_vector = &mut full_data_vector[i * train_dim..(i + 1) * train_dim];
            let mut distance = 0.0;
            match distance_function.as_str() {
                "l2" => {
                    for j in 0..train_dim {
                        let diff = query[j] - data_vector[j];
                        distance += diff * diff;
                    }
                    gt_map.push((distance, i));
                }
                "inner_product" => {
                    for j in 0..train_dim {
                        distance += query[j] * data_vector[j];
                    }
                    // Negate so ascending sort puts largest products first.
                    gt_map.push((-distance, i));
                }
                _ => panic!("Invalid distance function"),
            }
        }
        gt_map.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let closest_gt_vectors: Vec<usize> = gt_map
            .into_iter()
            .take(num_closest_gt_vectors)
            .map(|(_, value)| value)
            .collect();
        // Count ground-truth neighbors recovered within the PQ top-100.
        let counter = closest_gt_vectors
            .iter()
            .filter(|&point| closest_pq_vectors.contains(point))
            .count();
        *item = counter;
    }
    let recall_percentage: f32 = ((counter_sum.iter().sum::<usize>() as f32 / num_runs as f32)
        / num_closest_gt_vectors as f32)
        * 100.0;
    println!(
        "\n\nOriginal data dimension: {}, Number of PQ chunks: {}",
        train_dim, num_pq_chunks
    );
    println!(
        "Data file: {}, Distance function: {}, Recall: {}",
        data_file, distance_function, recall_percentage
    );
    assert!(recall_percentage > 90.0);
}
}