use std::io::{Seek, SeekFrom, Write};

use diskann::{
    ANNError, ANNResult,
    utils::{IntoUsize, VectorRepr},
};
use diskann_utils::{
    io::{Metadata, read_bin, write_bin},
    views::{Matrix, MatrixView},
};
use rand::Rng;
use tracing::info;

use super::{StorageReadProvider, StorageWriteProvider};
use crate::{
    model::{FixedChunkPQTable, NUM_PQ_CENTROIDS, pq::METADATA_SIZE},
    utils::{gen_random_slice, read_bin_from, write_bin_from},
};
type FullPivotDataType = Vec<f32>;
type CentroidType = Vec<f32>;
type ChunkOffsetsType = Vec<usize>;
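/// Paths to the files that back a product-quantization (PQ) index: the
/// k-means pivot file, the compressed-code file, and (optionally) the raw
/// training data used to build them.
///
/// A minimal construction sketch (the paths are illustrative):
///
/// ```ignore
/// let pq_storage = PQStorage::new(
///     "/index/pq_pivots.bin",
///     "/index/pq_compressed.bin",
///     Some("/index/train_data.bin"),
/// );
/// ```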
#[derive(Debug, Clone)]
pub struct PQStorage {
pivot_data_path: String,
compressed_data_path: String,
data_path: Option<String>,
}
impl PQStorage {
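    /// Creates a `PQStorage` over the given file paths. `data_path` is
    /// optional and is only required by operations that read the raw
    /// training data, such as `get_random_train_data_slice`.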
pub fn new(pivot_data_path: &str, compressed_data_path: &str, data_path: Option<&str>) -> Self {
Self {
pivot_data_path: pivot_data_path.to_string(),
compressed_data_path: compressed_data_path.to_string(),
data_path: data_path.map(|x| x.to_string()),
}
}
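    /// Writes the (npts, pq_chunk) metadata header of a compressed-code
    /// file to `writer`.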
pub fn write_compressed_pivot_metadata<Storage>(
&self,
npts: usize,
pq_chunk: usize,
writer: &mut Storage::Writer,
) -> ANNResult<()>
where
Storage: StorageWriteProvider,
{
Metadata::new(npts, pq_chunk)?.write(writer)?;
Ok(())
}
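    /// Writes the pivot file: the pivot matrix, the centroid vector, and
    /// the chunk offsets as consecutive binary sections starting at
    /// `METADATA_SIZE`, then records the cumulative byte offset of each
    /// section at offset 0.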
pub fn write_pivot_data<Storage>(
&self,
full_pivot_data: &[f32],
centroid: &[f32],
chunk_offsets: &[usize],
num_centers: usize,
dim: usize,
storage_provider: &Storage,
) -> ANNResult<()>
where
Storage: StorageWriteProvider,
{
        // The first METADATA_SIZE bytes are reserved for the table of
        // cumulative section offsets, which is written last, once every
        // section size is known.
        let mut cumul_bytes: Vec<usize> = vec![0; 4];
        cumul_bytes[0] = METADATA_SIZE;
        let writer = &mut storage_provider.create_for_write(&self.pivot_data_path)?;
        writer.seek(SeekFrom::Start(cumul_bytes[0] as u64))?;
        // Section 1: the num_centers x dim pivot matrix.
        let pivot_view = MatrixView::try_from(full_pivot_data, num_centers, dim)?;
        cumul_bytes[1] = cumul_bytes[0] + write_bin(pivot_view, writer)?;
        // Section 2: the centroid vector.
        cumul_bytes[2] = cumul_bytes[1] + write_bin(MatrixView::column_vector(centroid), writer)?;
        // Section 3: the chunk offsets, stored as u32.
        let chunk_offsets_u32: Vec<u32> = chunk_offsets.iter().map(|&x| x as u32).collect();
        cumul_bytes[3] = cumul_bytes[2]
            + write_bin(
                MatrixView::column_vector(chunk_offsets_u32.as_slice()),
                writer,
            )?;
        // Record where each section starts at the head of the file.
        let cumul_bytes_u64: Vec<u64> = cumul_bytes.iter().map(|&x| x as u64).collect();
        write_bin_from(
            MatrixView::column_vector(cumul_bytes_u64.as_slice()),
            writer,
            0,
        )?;
        writer.flush()?;
Ok(())
}
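    /// Returns true if the pivot file already exists in `storage_provider`.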
pub fn pivot_data_exist<Storage>(&self, storage_provider: &Storage) -> bool
where
Storage: StorageReadProvider,
{
storage_provider.exists(&self.pivot_data_path)
}
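    /// Reads the (num_centers, dim) header of the pivot matrix, which is
    /// stored at `METADATA_SIZE` in an existing pivot file.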
pub fn read_existing_pivot_metadata<Storage>(
&self,
storage_provider: &Storage,
) -> std::io::Result<(usize, usize)>
where
Storage: StorageReadProvider,
{
let reader = &mut storage_provider.open_reader(&self.pivot_data_path)?;
reader.seek(SeekFrom::Start(METADATA_SIZE as u64))?;
Ok(Metadata::read(reader)?.into_dims())
}
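    /// Loads the pivot matrix, centroid vector, and chunk offsets from an
    /// existing pivot file, validating each section against the expected
    /// shape before returning it.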
    pub fn load_existing_pivot_data<Storage>(
        &self,
        num_pq_chunks: usize,
        num_centers: usize,
        dim: usize,
        storage_provider: &Storage,
    ) -> ANNResult<(FullPivotDataType, CentroidType, ChunkOffsetsType)>
where
Storage: StorageReadProvider,
{
let reader = &mut storage_provider.open_reader(&self.pivot_data_path)?;
let offsets = read_bin_from::<u64>(reader, 0)?;
if offsets.nrows() != 4 {
return Err(ANNError::log_pq_error(format_args!(
"Error reading pq_pivots file {}. Offsets don't contain correct \
metadata, # offsets = {}, but expecting 4.",
&self.pivot_data_path,
offsets.nrows()
)));
}
let file_offset_data = offsets.map(|x| x.into_usize());
info!(" Offset data: {:?}", file_offset_data.as_slice());
let pivots = read_bin_from::<f32>(reader, file_offset_data[(0, 0)])?;
        if pivots.nrows() != num_centers || pivots.ncols() != dim {
return Err(ANNError::log_pq_error(format_args!(
"Error reading pq_pivots file {}. file_num_centers = {}, \
file_dim = {} but expecting {} centers in {} dimensions.",
&self.pivot_data_path,
pivots.nrows(),
pivots.ncols(),
num_centers,
dim
)));
}
let centroid_m = read_bin_from::<f32>(reader, file_offset_data[(1, 0)])?;
        if centroid_m.nrows() != dim || centroid_m.ncols() != 1 {
return Err(ANNError::log_pq_error(format_args!(
"Error reading pq_pivots file {}. file_dim = {}, \
file_cols = {} but expecting {} entries in 1 dimension.",
&self.pivot_data_path,
centroid_m.nrows(),
centroid_m.ncols(),
dim
)));
}
let chunk_offsets_m = read_bin_from::<u32>(reader, file_offset_data[(2, 0)])?;
        if chunk_offsets_m.nrows() != num_pq_chunks + 1 || chunk_offsets_m.ncols() != 1 {
            return Err(ANNError::log_pq_error(format_args!(
                "Error reading pq_pivots file {} at chunk offsets; \
                file has nr={}, nc={} but expecting nr={} and nc=1.",
                &self.pivot_data_path,
                chunk_offsets_m.nrows(),
                chunk_offsets_m.ncols(),
                num_pq_chunks + 1
            )));
}
let chunk_offsets = chunk_offsets_m.map(|x| x.into_usize());
Ok((
pivots.into_inner().into_vec(),
centroid_m.into_inner().into_vec(),
chunk_offsets.into_inner().into_vec(),
))
}
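    /// Loads the compressed PQ codes from `pq_compressed_data` as a
    /// `num_points_to_load` x `num_pq_chunks` matrix of `u8` codes.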
pub fn load_pq_compressed_vectors_bin<Storage: StorageReadProvider>(
pq_compressed_data: &str,
num_points_to_load: usize,
num_pq_chunks: usize,
storage_provider: &Storage,
) -> ANNResult<Matrix<u8>> {
        info!(
            "Loading compressed vectors from PQ compressed data file {}...",
            pq_compressed_data,
        );
        info!(
            "# of points: {}, # of PQ chunks: {}",
            num_points_to_load, num_pq_chunks
        );
let data = read_bin::<u8>(&mut storage_provider.open_reader(pq_compressed_data)?)?;
if data.nrows() != num_points_to_load || data.ncols() != num_pq_chunks {
return Err(ANNError::log_pq_error(format_args!(
"PQ compressed data mismatch: file has {}x{} but expected {}x{}",
data.nrows(),
data.ncols(),
num_points_to_load,
num_pq_chunks
)));
}
info!("PQ compressed dataset loaded.");
Ok(data)
}
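    /// Loads a `FixedChunkPQTable` from the pivot file at `pq_pivots`.
    /// Pass `num_pq_chunks` as 0 to accept whatever chunk count the file
    /// contains.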
pub fn load_pq_pivots_bin<Storage: StorageReadProvider>(
&self,
pq_pivots: &str,
num_pq_chunks: usize,
storage_provider: &Storage,
) -> ANNResult<FixedChunkPQTable> {
        if !storage_provider.exists(pq_pivots) {
            return Err(ANNError::log_pq_error(format_args!(
                "PQ k-means pivot file {} not found.",
                pq_pivots
            )));
        }
info!("Loading PQ pivots from {}...", pq_pivots);
let mut reader = storage_provider.open_reader(pq_pivots)?;
let offsets = read_bin_from::<u64>(&mut reader, 0)?;
if offsets.nrows() != 4 {
return Err(ANNError::log_pq_error(format_args!(
"Error reading pq_pivots file {}. Offsets don't contain correct metadata, \
# offsets = {}, but expecting 4.",
pq_pivots,
offsets.nrows()
)));
}
let file_offset_data = offsets.map(|x| x.into_usize());
let pivots = read_bin_from::<f32>(&mut reader, file_offset_data[(0, 0)])?;
if pivots.nrows() > NUM_PQ_CENTROIDS {
return Err(ANNError::log_pq_error(format_args!(
"Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.",
pq_pivots,
pivots.nrows(),
NUM_PQ_CENTROIDS
)));
}
let dim = pivots.ncols();
let centroids = read_bin_from::<f32>(&mut reader, file_offset_data[(1, 0)])?;
if centroids.nrows() != dim || centroids.ncols() != 1 {
return Err(ANNError::log_pq_error(format_args!(
"Error reading pq_pivots file {}. file_dim = {}, file_cols = {} \
but expecting {} entries in 1 dimension.",
pq_pivots,
centroids.nrows(),
centroids.ncols(),
dim
)));
}
let chunk_offsets_m = read_bin_from::<u32>(&mut reader, file_offset_data[(2, 0)])?;
        if (chunk_offsets_m.nrows() != num_pq_chunks + 1 && num_pq_chunks != 0)
|| chunk_offsets_m.ncols() != 1
{
return Err(ANNError::log_pq_error(format_args!(
"Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} \
but expecting nr={} and nc=1. The expected num_pq_chunks should be \
passed as 0 if we want to infer.",
chunk_offsets_m.nrows(),
chunk_offsets_m.ncols(),
num_pq_chunks + 1
)));
}
let chunk_offsets = chunk_offsets_m.map(|x| x.into_usize());
FixedChunkPQTable::new(
dim,
pivots.into_inner(),
centroids.into_inner(),
chunk_offsets.into_inner(),
)
}
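    /// Samples a random subset of the training data file (roughly a `p_val`
    /// fraction of the vectors), returning the sampled vectors as `f32`
    /// together with the sample count and the dimensionality.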
pub fn get_random_train_data_slice<T: VectorRepr, Storage>(
&self,
p_val: f64,
storage_provider: &Storage,
generator: &mut impl Rng,
) -> ANNResult<(Vec<f32>, usize, usize)>
where
Storage: StorageReadProvider,
{
gen_random_slice::<T, _>(self.get_data_path()?, p_val, storage_provider, generator)
}
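    /// Returns the training data path, or an index-config error if no
    /// `data_path` was supplied at construction.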
pub fn get_data_path(&self) -> ANNResult<&str> {
self.data_path
.as_ref()
.ok_or_else(|| {
ANNError::log_index_config_error(
"data_path".to_string(),
"pq_storage.data_path is not defined".to_string(),
)
})
.map(|s| s.as_str())
}
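    /// Returns the path of the compressed-code file.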
pub fn get_compressed_data_path(&self) -> &str {
&self.compressed_data_path
}
}
#[cfg(test)]
mod pq_storage_tests {
use crate::storage::VirtualStorageProvider;
use diskann_utils::test_data_root;
use vfs::MemoryFS;
use super::*;
use crate::utils::gen_random_slice;
const DATA_FILE: &str = "/sift/siftsmall_learn.bin";
const PQ_PIVOT_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
const PQ_COMPRESSED_PATH: &str = "/sift/empty_pq_compressed.bin";
#[test]
fn new_test() {
PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, Some(DATA_FILE));
}
#[test]
fn write_compressed_pivot_metadata_test() {
let storage_provider = VirtualStorageProvider::new_memory();
let compress_pivot_path = "/write_compressed_pivot_metadata_test.bin";
        let pq_storage = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, Some(DATA_FILE));
{
let mut writer = storage_provider
.create_for_write(compress_pivot_path)
.unwrap();
            pq_storage
.write_compressed_pivot_metadata::<VirtualStorageProvider<MemoryFS>>(
100,
20,
&mut writer,
)
.unwrap();
}
let mut result_reader = storage_provider.open_reader(compress_pivot_path).unwrap();
let metadata = Metadata::read(&mut result_reader).unwrap();
assert_eq!(metadata.npoints(), 100);
assert_eq!(metadata.ndims(), 20);
storage_provider.delete(compress_pivot_path).unwrap();
}
#[test]
fn pivot_data_exist_test() {
let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let pq_storage = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, Some(DATA_FILE));
        assert!(pq_storage.pivot_data_exist(&storage_provider));
        let pivot_path = "not_exist_pivot_path.bin";
        let pq_storage = PQStorage::new(pivot_path, PQ_COMPRESSED_PATH, Some(DATA_FILE));
        assert!(!pq_storage.pivot_data_exist(&storage_provider));
}
#[test]
fn read_pivot_metadata_test() {
let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let pq_storage = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, Some(DATA_FILE));
        let (npt, dim) = pq_storage
            .read_existing_pivot_metadata(&storage_provider)
.unwrap();
assert_eq!(npt, 256);
assert_eq!(dim, 128);
}
#[test]
fn load_pivot_data_test() {
let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let pq_storage = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, Some(DATA_FILE));
        let (pq_pivot_data, centroids, chunk_offsets) = pq_storage
            .load_existing_pivot_data(1, 256, 128, &storage_provider)
.unwrap();
assert_eq!(pq_pivot_data.len(), 256 * 128);
assert_eq!(centroids.len(), 128);
assert_eq!(chunk_offsets.len(), 2);
}
#[test]
fn gen_random_slice_test() {
let storage_provider = VirtualStorageProvider::new_memory();
let file_name = "/gen_random_slice_test.bin";
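        // A 2-point, 8-dimension test file: a (npoints = 2, ndims = 8) u32
        // header followed by 16 little-endian f32 values, 1.0 through 16.0.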
let data: [u8; 72] = [
2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
];
{
let mut writer = storage_provider.create_for_write(file_name).unwrap();
writer
.write_all(&data)
.expect("Failed to write sample file");
}
let (sampled_vectors, slice_size, ndims) =
gen_random_slice::<f32, VirtualStorageProvider<MemoryFS>>(
file_name,
1f64,
&storage_provider,
&mut crate::utils::create_rnd_in_tests(),
)
.unwrap();
let mut start = 8;
(0..sampled_vectors.len()).for_each(|i| {
assert_eq!(sampled_vectors[i].to_le_bytes(), data[start..start + 4]);
start += 4;
});
assert_eq!(sampled_vectors.len(), 16);
assert_eq!(slice_size, 2);
assert_eq!(ndims, 8);
let (sampled_vectors, slice_size, ndims) =
gen_random_slice::<f32, VirtualStorageProvider<MemoryFS>>(
file_name,
0f64,
&storage_provider,
&mut crate::utils::create_rnd_in_tests(),
)
.unwrap();
assert_eq!(sampled_vectors.len(), 0);
assert_eq!(slice_size, 0);
assert_eq!(ndims, 8);
storage_provider
.delete(file_name)
.expect("Failed to delete file");
}
}