use std::{future::Future, sync::Mutex};
use crate::storage::{StorageReadProvider, StorageWriteProvider};
use diskann::{
ANNError, ANNResult, default_post_processor,
graph::{
glue::{
self, DefaultPostProcessor, ExpandBeam, FilterStartPoints, InsertStrategy, Pipeline,
PruneStrategy, SearchExt, SearchStrategy,
},
workingset,
},
provider::{
Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext,
HasId,
},
utils::{IntoUsize, VectorRepr},
};
use diskann_quantization::{
AsFunctor, CompressInto,
bits::{Representation, Unsigned},
meta::NotCanonical,
scalar::{
CompensatedCosineNormalized, CompensatedIP, CompensatedSquaredL2, CompensatedVector,
CompensatedVectorRef, InputContainsNaN, MeanNormMissing, MutCompensatedVectorRef,
ScalarQuantizer,
},
};
use diskann_utils::{Reborrow, ReborrowMut, future::AsyncFriendly};
use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
use thiserror::Error;
use super::{DefaultProvider, GetFullPrecision, PassThrough, Rerank};
use crate::{
common::IgnoreLockPoison,
model::graph::provider::async_::{
FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync,
common::{
AlignedMemoryVectorStore, CreateVectorStore, NoStore, Quantized, SetElementHelper,
TestCallCount, VectorStore,
},
inmem::{FullPrecisionProvider, FullPrecisionStore},
postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy},
},
storage::{self, AsyncIndexMetadata, AsyncQuantLoadContext, LoadWith, SaveWith},
};
/// Shorthand for a borrowed compensated vector holding `NBITS`-bit elements.
type CVRef<'a, const NBITS: usize> = CompensatedVectorRef<'a, NBITS>;
/// Configuration marker selecting an `NBITS`-bit scalar-quantization scheme.
///
/// Implements `CreateVectorStore` (below) to build an `SQStore<NBITS>`.
#[derive(Clone)]
pub struct WithBits<const NBITS: usize> {
    /// Trained scalar quantizer used to compress full-precision vectors.
    quantizer: ScalarQuantizer,
}
impl<const NBITS: usize> WithBits<NBITS> {
    /// Wraps a trained [`ScalarQuantizer`] for later store construction.
    pub fn new(quantizer: ScalarQuantizer) -> Self {
        Self { quantizer }
    }
}
/// Number of consecutive vector slots guarded by one write lock (lock striping).
const WRITE_LOCK_GRANULARITY: usize = 16;
/// Default prefetch lookahead used when none is supplied at construction.
const PREFETCH_DEFAULT: usize = 8;
/// In-memory store of scalar-quantized (`NBITS`-bit) vectors.
///
/// Writes are serialized per stripe of `WRITE_LOCK_GRANULARITY` slots via
/// `write_locks`; reads take no lock.
pub struct SQStore<const NBITS: usize> {
    /// Raw quantized bytes, one canonical-layout record per vector slot.
    data: AlignedMemoryVectorStore<u8>,
    /// Quantizer used to compress vectors and build distance functors.
    quantizer: ScalarQuantizer,
    /// Distance metric chosen at construction time.
    metric: Metric,
    /// Striped write locks; a slot `i` maps to lock `i / WRITE_LOCK_GRANULARITY`.
    write_locks: Vec<Mutex<()>>,
    /// How many vectors to prefetch ahead during batched element access.
    prefetch_lookahead: usize,
    /// Test-support counter of `get_vector` calls.
    num_get_calls: TestCallCount,
}
impl<const NBITS: usize> SQStore<NBITS>
where
    Unsigned: Representation<NBITS>,
{
    /// Creates a store with capacity for `num_vectors` quantized vectors.
    ///
    /// `prefetch_lookahead` falls back to `PREFETCH_DEFAULT` when `None`.
    pub(super) fn new(
        quantizer: ScalarQuantizer,
        num_vectors: usize,
        metric: Metric,
        prefetch_lookahead: Option<usize>,
    ) -> Self {
        // One lock per WRITE_LOCK_GRANULARITY slots, rounded up.
        let write_locks = (0..num_vectors.div_ceil(WRITE_LOCK_GRANULARITY))
            .map(|_| Mutex::new(()))
            .collect::<Vec<_>>();
        // Bytes per record in the canonical compensated-vector layout.
        let bytes = CVRef::<NBITS>::canonical_bytes(quantizer.dim());
        Self {
            data: AlignedMemoryVectorStore::with_capacity(num_vectors, bytes),
            quantizer,
            metric,
            write_locks,
            num_get_calls: TestCallCount::default(),
            prefetch_lookahead: prefetch_lookahead.unwrap_or(PREFETCH_DEFAULT),
        }
    }

    /// Issues a CPU prefetch hint for vector `i`'s bytes.
    ///
    /// An out-of-bounds `i` panics (exercised by the tests below).
    pub(crate) fn prefetch_hint(&self, i: usize) {
        // SAFETY: `get_slice` is unsafe per AlignedMemoryVectorStore's
        // contract; out-of-range indices panic rather than read wild memory
        // (see `test_prefetch_hint_oob`) — TODO(review): confirm the full
        // safety contract of `get_slice` at its definition.
        let data = unsafe { self.data.get_slice(i) };
        diskann_vector::prefetch_hint_max::<4, _>(data);
    }

    /// Vector dimensionality (pre-quantization, in elements).
    pub(super) fn dim(&self) -> usize {
        self.quantizer.dim()
    }

    /// Returns a borrowed view of quantized vector `i`.
    ///
    /// Errors if the stored bytes do not form a canonical-layout record.
    pub(super) fn get_vector(&self, i: usize) -> Result<CVRef<'_, NBITS>, SQError> {
        self.num_get_calls.increment();
        // SAFETY: out-of-bounds `i` panics (see `test_get_vector_oob`);
        // in-bounds slices are valid for the lifetime of `&self`.
        Ok(CVRef::from_canonical_front(
            unsafe { self.data.get_slice(i) },
            self.dim(),
        )?)
    }

    /// Quantizes `v` and writes it into slot `i` under the striped write lock.
    ///
    /// `v` is converted to `f32` first; its length must equal `dim()`
    /// (checked only in debug builds).
    pub(super) fn set_vector<T>(&self, i: usize, v: &[T]) -> Result<(), SQError>
    where
        T: VectorRepr,
    {
        let vf32: &[f32] =
            &T::as_f32(v).map_err(|e| SQError::FullPrecisionConversionErr(format!("{:?}", e)))?;
        debug_assert!(
            vf32.len() == self.dim(),
            "vector f32 dimension {} does not match dimension {}",
            vf32.len(),
            self.dim()
        );
        // Serialize writes to slots sharing this lock stripe.
        let lock_id = i / WRITE_LOCK_GRANULARITY;
        let _guard = self.write_locks[lock_id].lock_or_panic();
        self.quantizer.compress_into(
            vf32,
            // SAFETY: write happens under the stripe lock; out-of-bounds `i`
            // panics per the store's slice accessors.
            MutCompensatedVectorRef::<NBITS>::from_canonical_front_mut(
                unsafe { self.data.get_mut_slice(i) },
                self.dim(),
            )?,
        )?;
        Ok(())
    }

    /// Copies already-quantized bytes `v` directly into slot `i`.
    ///
    /// # Safety
    ///
    /// `v` must be a valid canonical-layout record for this quantizer and be
    /// exactly `CVRef::<NBITS>::canonical_bytes(dim())` long; the length is
    /// verified only in debug builds.
    pub(crate) unsafe fn set_quant_vector(&self, i: usize, v: &[u8]) -> ANNResult<()> {
        let expected_quant_len = CVRef::<NBITS>::canonical_bytes(self.dim());
        debug_assert!(
            v.len() == expected_quant_len,
            "vector length {} does not match dimension {}",
            v.len(),
            expected_quant_len
        );
        // Serialize writes to slots sharing this lock stripe.
        let lock_id = i / WRITE_LOCK_GRANULARITY;
        let _guard = self.write_locks[lock_id].lock_or_panic();
        // SAFETY: caller guarantees `v` matches the record layout; slot access
        // is serialized by the stripe lock held above.
        unsafe { self.data.get_mut_slice(i) }.copy_from_slice(v);
        Ok(())
    }

    /// Builds the metric-specific distance functor for this store.
    ///
    /// Only L2, InnerProduct and CosineNormalized are supported; any other
    /// metric yields `SQError::UnsupportedDistanceMetric`.
    pub(super) fn distance_computer(&self) -> Result<DistanceComputer, SQError> {
        Ok(match self.metric {
            Metric::L2 => DistanceComputer::SquaredL2(self.quantizer.as_functor()),
            Metric::InnerProduct => DistanceComputer::InnerProduct(self.quantizer.as_functor()),
            Metric::CosineNormalized => {
                DistanceComputer::CosineNormalized(self.quantizer.as_functor())
            }
            unsupported_metric => {
                return Err(SQError::UnsupportedDistanceMetric(unsupported_metric));
            }
        })
    }

    /// Compresses `query` and pairs it with this store's distance computer.
    ///
    /// When `allow_rescale` is set and the metric is neither L2 nor
    /// CosineNormalized (i.e. InnerProduct here), the query is rescaled by the
    /// quantizer before compression.
    pub(super) fn query_computer<T>(
        &self,
        query: &[T],
        allow_rescale: bool,
    ) -> Result<QueryComputer<NBITS>, SQError>
    where
        T: VectorRepr,
    {
        let mut boxed = CompensatedVector::new_boxed(self.dim());
        let q = T::as_f32(query)
            .map_err(|e| SQError::FullPrecisionConversionErr(format!("{:?}", e)))?;
        if allow_rescale && !matches!(self.metric, Metric::L2 | Metric::CosineNormalized) {
            // Rescale a private copy so the caller's query is untouched.
            let mut query: Box<[f32]> = q.as_ref().into();
            self.quantizer.rescale(&mut query)?;
            self.quantizer
                .compress_into(&*query, boxed.reborrow_mut())?;
        } else {
            self.quantizer
                .compress_into(q.as_ref(), boxed.reborrow_mut())?;
        }
        Ok(QueryComputer {
            inner: self.distance_computer()?,
            query: boxed,
        })
    }

    /// Number of vectors to prefetch ahead during batched reads.
    pub fn prefetch_lookahead(&self) -> usize {
        self.prefetch_lookahead
    }
}
/// Metric-specialized distance functor over compensated quantized vectors.
///
/// Variants correspond to the metrics accepted by `SQStore::distance_computer`.
#[derive(Debug)]
pub enum DistanceComputer {
    SquaredL2(CompensatedSquaredL2),
    InnerProduct(CompensatedIP),
    CosineNormalized(CompensatedCosineNormalized),
}
/// Dispatches similarity evaluation to the active metric's functor.
impl<const NBITS: usize> DistanceFunction<CVRef<'_, NBITS>, CVRef<'_, NBITS>, f32>
    for DistanceComputer
where
    Unsigned: Representation<NBITS>,
    CompensatedSquaredL2: for<'a, 'b> DistanceFunction<
        CVRef<'a, NBITS>,
        CVRef<'b, NBITS>,
        diskann_quantization::distances::Result<f32>,
    >,
    CompensatedIP: for<'a, 'b> DistanceFunction<
        CVRef<'a, NBITS>,
        CVRef<'b, NBITS>,
        diskann_quantization::distances::Result<f32>,
    >,
    CompensatedCosineNormalized: for<'a, 'b> DistanceFunction<
        CVRef<'a, NBITS>,
        CVRef<'b, NBITS>,
        diskann_quantization::distances::Result<f32>,
    >,
{
    /// Evaluates similarity between two quantized vectors.
    ///
    /// Panics (via `err.panic`) if the inner functor reports an error, e.g.
    /// on mismatched operand lengths.
    #[inline(always)]
    fn evaluate_similarity(&self, left: CVRef<'_, NBITS>, right: CVRef<'_, NBITS>) -> f32 {
        let r = match self {
            DistanceComputer::SquaredL2(f) => f.evaluate_similarity(left, right),
            DistanceComputer::InnerProduct(f) => f.evaluate_similarity(left, right),
            DistanceComputer::CosineNormalized(f) => f.evaluate_similarity(left, right),
        };
        r.map_err(|err| err.panic(left.len(), right.len())).unwrap()
    }
}
/// A distance functor pre-bound to a single compressed query vector.
pub struct QueryComputer<const NBITS: usize>
where
    Unsigned: Representation<NBITS>,
{
    /// Metric-specific distance computer shared with the owning store.
    inner: DistanceComputer,
    /// The quantized query, compared against each candidate vector.
    query: CompensatedVector<NBITS>,
}
/// Computes the distance from the stored query to a changing candidate.
impl<const NBITS: usize> PreprocessedDistanceFunction<CVRef<'_, NBITS>, f32>
    for QueryComputer<NBITS>
where
    Unsigned: Representation<NBITS>,
    DistanceComputer: for<'a, 'b> DistanceFunction<CVRef<'a, NBITS>, CVRef<'b, NBITS>, f32>,
{
    fn evaluate_similarity(&self, changing: CVRef<'_, NBITS>) -> f32 {
        self.inner
            .evaluate_similarity(self.query.reborrow(), changing)
    }
}
/// Turns the `WithBits` configuration into a concrete `SQStore`.
impl<const NBITS: usize> CreateVectorStore for WithBits<NBITS>
where
    Unsigned: Representation<NBITS>,
{
    type Target = SQStore<NBITS>;

    /// Builds an `SQStore` sized for `max_points` vectors.
    fn create(
        self,
        max_points: usize,
        metric: Metric,
        prefetch_lookahead: Option<usize>,
    ) -> Self::Target {
        SQStore::new(self.quantizer, max_points, metric, prefetch_lookahead)
    }
}
/// Capacity and test-instrumentation accessors for `SQStore`.
impl<const NBITS: usize> VectorStore for SQStore<NBITS> {
    /// Maximum number of vector slots in the store.
    fn total(&self) -> usize {
        self.data.max_vectors()
    }

    /// Number of `get_vector` calls recorded (test support).
    fn count_for_get_vector(&self) -> usize {
        self.num_get_calls.get()
    }
}
/// Adapts `SQStore::set_vector` to the generic `SetElementHelper` interface.
impl<T, const NBITS: usize> SetElementHelper<T> for SQStore<NBITS>
where
    T: VectorRepr,
    Unsigned: Representation<NBITS>,
{
    fn set_element(&self, id: &u32, element: &[T]) -> ANNResult<()> {
        self.set_vector(id.into_usize(), element)?;
        Ok(())
    }
}
/// Accessor over a provider's quantized auxiliary vectors.
///
/// Created per search or prune pass; see `QuantAccessor::new`.
pub struct QuantAccessor<'a, const NBITS: usize, V, D, Ctx> {
    /// The provider whose `aux_vectors` (an `SQStore`) back this accessor.
    provider: &'a DefaultProvider<V, SQStore<NBITS>, D, Ctx>,
    /// Scratch buffer of ids reused across `on_elements_unordered` calls.
    id_buffer: Vec<u32>,
    /// True for search accessors; forwarded as `allow_rescale` to
    /// `query_computer` (see `build_query_computer`).
    is_search: bool,
}
/// Exposes the provider's full-precision base vectors (used for reranking).
impl<T, const NBITS: usize, D, Ctx> GetFullPrecision
    for QuantAccessor<'_, NBITS, FullPrecisionStore<T>, D, Ctx>
where
    T: VectorRepr,
{
    type Repr = T;

    fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync<T> {
        &self.provider.base_vectors
    }
}
/// Vectors are addressed by `u32` ids.
impl<const NBITS: usize, V, D, Ctx> HasId for QuantAccessor<'_, NBITS, V, D, Ctx> {
    type Id = u32;
}
/// Supplies the graph entry points for a search (delegated to the provider).
impl<const NBITS: usize, V, D, Ctx> SearchExt for QuantAccessor<'_, NBITS, V, D, Ctx>
where
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
{
    fn starting_points(&self) -> impl Future<Output = ANNResult<Vec<u32>>> {
        // Synchronous lookup wrapped in an immediately-ready future.
        std::future::ready(self.provider.starting_points())
    }
}
impl<'a, const NBITS: usize, V, D, Ctx> QuantAccessor<'a, NBITS, V, D, Ctx>
where
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
{
    /// Creates an accessor over `provider`'s quantized vectors.
    ///
    /// `is_search` marks search (vs. prune) accessors and enables query
    /// rescaling in `build_query_computer`.
    pub(crate) fn new(
        provider: &'a DefaultProvider<V, SQStore<NBITS>, D, Ctx>,
        is_search: bool,
    ) -> Self {
        Self {
            provider,
            // Pre-size the id scratch buffer for a typical neighborhood.
            id_buffer: Vec::with_capacity(32),
            is_search,
        }
    }
}
/// Element access over quantized vectors, with software prefetching.
impl<const NBITS: usize, V, D, Ctx> Accessor for QuantAccessor<'_, NBITS, V, D, Ctx>
where
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
{
    type Element<'a>
        = CVRef<'a, NBITS>
    where
        Self: 'a;
    type ElementRef<'a> = CVRef<'a, NBITS>;
    type GetError = ANNError;

    /// Fetches the quantized vector for `id`; the future is immediately ready.
    fn get_element(
        &mut self,
        id: Self::Id,
    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
        std::future::ready(
            match self.provider.aux_vectors.get_vector(id.into_usize()) {
                Ok(v) => Ok(v),
                Err(err) => Err(err.into()),
            },
        )
    }

    /// Applies `f` to each id's vector, keeping a prefetch window
    /// `lookahead` vectors ahead of the current element.
    fn on_elements_unordered<Itr, F>(
        &mut self,
        itr: Itr,
        mut f: F,
    ) -> impl Future<Output = Result<(), Self::GetError>> + Send
    where
        Self: Sync,
        Itr: Iterator<Item = Self::Id> + Send,
        F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id),
    {
        // Buffer the ids so we can look ahead for prefetching.
        let id_buffer = &mut self.id_buffer;
        id_buffer.clear();
        id_buffer.extend(itr);
        let len = id_buffer.len();
        let lookahead = self.provider.aux_vectors.prefetch_lookahead();
        // Warm the cache for the first `lookahead` vectors.
        for id in id_buffer.iter().take(lookahead) {
            self.provider.aux_vectors.prefetch_hint(id.into_usize());
        }
        for (i, id) in id_buffer.iter().enumerate() {
            // Prefetch the vector that is `lookahead` positions ahead.
            if lookahead > 0 && i + lookahead < len {
                self.provider
                    .aux_vectors
                    .prefetch_hint(id_buffer[i + lookahead].into_usize());
            }
            let vector = match self.provider.aux_vectors.get_vector(id.into_usize()) {
                Ok(v) => v,
                // Abort the whole batch on the first read error.
                Err(e) => return std::future::ready(Err(e.into())),
            };
            f(vector, *id)
        }
        std::future::ready(Ok(()))
    }
}
/// Neighbor lists are served by the provider's neighbor provider.
impl<'a, const NBITS: usize, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, NBITS, V, D, Ctx>
where
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
{
    type Delegate = &'a SimpleNeighborProviderAsync<u32>;

    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
        self.provider.neighbors()
    }
}
/// Builds a query computer by quantizing the full-precision query.
impl<const NBITS: usize, V, D, Ctx, T> BuildQueryComputer<&[T]>
    for QuantAccessor<'_, NBITS, V, D, Ctx>
where
    T: VectorRepr,
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    QueryComputer<NBITS>: for<'a> PreprocessedDistanceFunction<CVRef<'a, NBITS>, f32>,
{
    type QueryComputerError = ANNError;
    type QueryComputer = QueryComputer<NBITS>;

    /// Compresses `from` into a `QueryComputer`; `is_search` enables query
    /// rescaling inside `SQStore::query_computer`.
    fn build_query_computer(
        &self,
        from: &[T],
    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
        Ok(self
            .provider
            .aux_vectors
            .query_computer(from, self.is_search)?)
    }
}
/// Marker impl: beam expansion uses the default implementation from the
/// `ExpandBeam` trait, built on this accessor's element and query access.
impl<const NBITS: usize, V, D, Ctx, T> ExpandBeam<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx>
where
    T: VectorRepr,
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    QueryComputer<NBITS>: for<'a> PreprocessedDistanceFunction<CVRef<'a, NBITS>, f32>,
{
}
/// Exposes the store's metric-specific distance computer.
impl<const NBITS: usize, V, D, Ctx> BuildDistanceComputer for QuantAccessor<'_, NBITS, V, D, Ctx>
where
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    DistanceComputer: for<'a, 'b> DistanceFunction<CVRef<'a, NBITS>, CVRef<'b, NBITS>, f32>,
{
    type DistanceComputerError = ANNError;
    type DistanceComputer = DistanceComputer;

    fn build_distance_computer(
        &self,
    ) -> Result<Self::DistanceComputer, Self::DistanceComputerError> {
        Ok(self.provider.aux_vectors.distance_computer()?)
    }
}
/// Deletion checks are delegated to the provider's `deleted` tracker.
impl<const NBITS: usize, V, D, Ctx> AsDeletionCheck for QuantAccessor<'_, NBITS, V, D, Ctx>
where
    V: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
{
    type Checker = D;

    fn as_deletion_check(&self) -> &D {
        &self.provider.deleted
    }
}
/// Search over a provider that also holds full-precision vectors:
/// the quantized vectors drive the beam search (reranking is configured by
/// the matching `DefaultPostProcessor` impl).
impl<const NBITS: usize, D, Ctx, T>
    SearchStrategy<FullPrecisionProvider<T, SQStore<NBITS>, D, Ctx>, &[T]> for Quantized
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    QueryComputer<NBITS>: for<'a> PreprocessedDistanceFunction<CVRef<'a, NBITS>, f32>,
{
    type QueryComputer = QueryComputer<NBITS>;
    type SearchAccessor<'a> = QuantAccessor<'a, NBITS, FullPrecisionStore<T>, D, Ctx>;
    type SearchAccessorError = ANNError;

    fn search_accessor<'a>(
        &'a self,
        provider: &'a FullPrecisionProvider<T, SQStore<NBITS>, D, Ctx>,
        _context: &'a Ctx,
    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
        // `true`: search accessors allow query rescaling.
        Ok(QuantAccessor::new(provider, true))
    }
}
/// Post-processing for full-precision providers: filter out start points,
/// then rerank candidates using the full-precision vectors.
impl<const NBITS: usize, D, Ctx, T>
    DefaultPostProcessor<FullPrecisionProvider<T, SQStore<NBITS>, D, Ctx>, &[T]> for Quantized
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    QueryComputer<NBITS>: for<'a> PreprocessedDistanceFunction<CVRef<'a, NBITS>, f32>,
{
    default_post_processor!(Pipeline<FilterStartPoints, Rerank>);
}
/// Search over a quantized-only provider (no full-precision vectors stored).
impl<const NBITS: usize, D, Ctx, T>
    SearchStrategy<DefaultProvider<NoStore, SQStore<NBITS>, D, Ctx>, &[T]> for Quantized
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    QueryComputer<NBITS>: for<'a> PreprocessedDistanceFunction<CVRef<'a, NBITS>, f32>,
{
    type QueryComputer = QueryComputer<NBITS>;
    type SearchAccessor<'a> = QuantAccessor<'a, NBITS, NoStore, D, Ctx>;
    type SearchAccessorError = ANNError;

    fn search_accessor<'a>(
        &'a self,
        provider: &'a DefaultProvider<NoStore, SQStore<NBITS>, D, Ctx>,
        _context: &'a Ctx,
    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
        // `true`: search accessors allow query rescaling.
        Ok(QuantAccessor::new(provider, true))
    }
}
/// Post-processing without full-precision vectors: no reranking is possible,
/// so only start-point filtering and deleted-id removal are applied.
impl<const NBITS: usize, D, Ctx, T>
    DefaultPostProcessor<DefaultProvider<NoStore, SQStore<NBITS>, D, Ctx>, &[T]> for Quantized
where
    T: VectorRepr,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    QueryComputer<NBITS>: for<'a> PreprocessedDistanceFunction<CVRef<'a, NBITS>, f32>,
{
    default_post_processor!(Pipeline<FilterStartPoints, RemoveDeletedIdsAndCopy>);
}
/// Pruning uses quantized distances directly; the working set is a
/// pass-through (no per-prune materialization needed).
impl<const NBITS: usize, V, D, Ctx> PruneStrategy<DefaultProvider<V, SQStore<NBITS>, D, Ctx>>
    for Quantized
where
    V: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    DistanceComputer: for<'a, 'b> DistanceFunction<CVRef<'a, NBITS>, CVRef<'b, NBITS>, f32>,
{
    type DistanceComputer = DistanceComputer;
    type PruneAccessor<'a> = QuantAccessor<'a, NBITS, V, D, Ctx>;
    type PruneAccessorError = diskann::error::Infallible;
    type WorkingSet = PassThrough;

    fn create_working_set(&self, _capacity: usize) -> Self::WorkingSet {
        PassThrough
    }

    fn prune_accessor<'a>(
        &'a self,
        provider: &'a DefaultProvider<V, SQStore<NBITS>, D, Ctx>,
        _context: &'a Ctx,
    ) -> Result<Self::PruneAccessor<'a>, Self::PruneAccessorError> {
        // `false`: prune accessors do not rescale queries.
        Ok(QuantAccessor::new(provider, false))
    }
}
/// Pass-through working-set fill: no copying — the accessor itself serves as
/// the view over the requested ids.
impl<const NBITS: usize, V, D, Ctx> workingset::Fill<PassThrough>
    for QuantAccessor<'_, NBITS, V, D, Ctx>
where
    V: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
{
    type Error = std::convert::Infallible;
    type View<'a>
        = &'a Self
    where
        Self: 'a;

    async fn fill<'a, Itr>(
        &'a mut self,
        _state: &'a mut PassThrough,
        _itr: Itr,
    ) -> Result<Self::View<'a>, Self::Error>
    where
        Itr: ExactSizeIterator<Item = Self::Id> + Clone + Send + Sync,
        Self: 'a,
    {
        Ok(self)
    }
}
/// View impl backing the pass-through working set: reads go straight to the
/// provider's quantized store.
impl<const NBITS: usize, V, D, Ctx> workingset::View<u32> for &QuantAccessor<'_, NBITS, V, D, Ctx>
where
    V: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
{
    type ElementRef<'a> = CVRef<'a, NBITS>;
    type Element<'a>
        = CVRef<'a, NBITS>
    where
        Self: 'a;

    /// Returns the quantized vector for `id`, or `None` on any read error.
    fn get(&self, id: u32) -> Option<Self::Element<'_>> {
        self.provider.aux_vectors.get_vector(id.into_usize()).ok()
    }
}
/// Insert support for [`Quantized`]: inserts search and prune with the same
/// quantized strategy value.
impl<const NBITS: usize, V, D, Ctx, T>
    InsertStrategy<DefaultProvider<V, SQStore<NBITS>, D, Ctx>, &[T]> for Quantized
where
    T: VectorRepr,
    V: AsyncFriendly,
    // A previously duplicated bare `D: AsyncFriendly` bound was removed; it
    // is implied by the combined bound below.
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    Unsigned: Representation<NBITS>,
    QueryComputer<NBITS>: for<'a> PreprocessedDistanceFunction<CVRef<'a, NBITS>, f32>,
    DistanceComputer: for<'a, 'b> DistanceFunction<CVRef<'a, NBITS>, CVRef<'b, NBITS>, f32>,
    Quantized: for<'a> SearchStrategy<DefaultProvider<V, SQStore<NBITS>, D, Ctx>, &'a [T]>,
{
    type PruneStrategy = Self;

    /// Pruning during insert reuses this same strategy value.
    fn prune_strategy(&self) -> Self::PruneStrategy {
        *self
    }
}
/// Batched insert support: each element is inserted with the same strategy,
/// and `finish` performs no extra work (pass-through seed).
impl<const NBITS: usize, V, D, Ctx, B>
    glue::MultiInsertStrategy<DefaultProvider<V, SQStore<NBITS>, D, Ctx>, B> for Quantized
where
    V: AsyncFriendly,
    D: AsyncFriendly + DeletionCheck,
    Ctx: ExecutionContext,
    B: glue::Batch,
    Self: PruneStrategy<DefaultProvider<V, SQStore<NBITS>, D, Ctx>, WorkingSet = PassThrough>
        + for<'a> InsertStrategy<
            DefaultProvider<V, SQStore<NBITS>, D, Ctx>,
            B::Element<'a>,
            PruneStrategy = Self,
        >,
{
    type WorkingSet = PassThrough;
    type Seed = PassThrough;
    type FinishError = diskann::error::Infallible;
    type InsertStrategy = Self;

    fn insert_strategy(&self) -> Self::InsertStrategy {
        *self
    }

    /// No batch-level finalization is needed for quantized inserts.
    fn finish<Itr>(
        &self,
        _provider: &DefaultProvider<V, SQStore<NBITS>, D, Ctx>,
        _ctx: &Ctx,
        _batch: &std::sync::Arc<B>,
        _ids: Itr,
    ) -> impl std::future::Future<Output = Result<Self::Seed, Self::FinishError>> + Send
    where
        Itr: ExactSizeIterator<Item = u32> + Send,
    {
        std::future::ready(Ok(PassThrough))
    }
}
/// Persists the quantized data and the quantizer under the metadata prefix.
impl<const NBITS: usize> SaveWith<AsyncIndexMetadata> for SQStore<NBITS> {
    /// Total bytes written (compressed data + quantizer).
    type Ok = usize;
    type Error = ANNError;

    async fn save_with<P>(
        &self,
        write_provider: &P,
        metadata: &AsyncIndexMetadata,
    ) -> Result<Self::Ok, Self::Error>
    where
        P: StorageWriteProvider,
    {
        let sq_storage = storage::SQStorage::new(metadata.prefix());
        // Save the raw quantized vectors, then the quantizer alongside them.
        let bytes_written =
            storage::bin::save_to_bin(self, write_provider, sq_storage.compressed_data_path())?;
        let quantizer_bytes_written = sq_storage.save_quantizer(&self.quantizer, write_provider)?;
        Ok(bytes_written + quantizer_bytes_written)
    }
}
/// Loads the quantizer and the quantized data saved by `save_with`.
impl<const NBITS: usize> LoadWith<AsyncQuantLoadContext> for SQStore<NBITS>
where
    Unsigned: Representation<NBITS>,
{
    type Error = ANNError;

    async fn load_with<P>(read_provider: &P, ctx: &AsyncQuantLoadContext) -> ANNResult<Self>
    where
        P: StorageReadProvider,
    {
        let sq_storage = storage::SQStorage::new(ctx.metadata.prefix());
        // The quantizer must be restored first; the data loader below needs it
        // to size and fill the store.
        let quantizer = sq_storage.load_quantizer(read_provider)?;
        storage::bin::load_from_bin(
            read_provider,
            sq_storage.compressed_data_path(),
            |num_points, _pq_bytes| {
                Ok(SQStore::<NBITS>::new(
                    quantizer,
                    num_points,
                    ctx.metric,
                    ctx.prefetch_lookahead,
                ))
            },
        )
    }
}
/// Deserialization hook: writes raw quantized bytes into slot `i`.
impl<const NBITS: usize> storage::bin::SetData for SQStore<NBITS>
where
    Unsigned: Representation<NBITS>,
{
    type Item = u8;

    fn set_data(&mut self, i: usize, element: &[Self::Item]) -> ANNResult<()> {
        // SAFETY: `element` comes from a file written by `GetData::get_data`,
        // so it is a canonical-layout record of the expected length
        // (debug-asserted inside `set_quant_vector`).
        unsafe { self.set_quant_vector(i, element) }
    }
}
/// Serialization hook: exposes the raw quantized bytes per slot.
impl<const NBITS: usize> storage::bin::GetData for SQStore<NBITS> {
    type Element = u8;
    type Item<'a> = &'a [u8];

    fn get_data(&self, i: usize) -> ANNResult<Self::Item<'_>> {
        // SAFETY: serialization iterates `0..total()`, keeping `i` in bounds.
        Ok(unsafe { self.data.get_slice(i) })
    }

    fn total(&self) -> usize {
        self.data.max_vectors()
    }

    /// Record width in bytes (not the logical vector dimension).
    fn dim(&self) -> usize {
        self.data.dim()
    }
}
#[derive(Debug, Error)]
pub enum SQError {
#[error("Issue with canonical layout of data: {0:?}")]
CanonicalLayoutError(#[from] NotCanonical),
#[error("Input contains NaN values.")]
InputContainsNaN(#[from] InputContainsNaN),
#[error("Input full-precision conversion error : {0}")]
FullPrecisionConversionErr(String),
#[error("Mean Norm is missing in the quantizer.")]
MeanNormMissing(#[from] MeanNormMissing),
#[error("Unsupported distance metric: {0:?}")]
UnsupportedDistanceMetric(Metric),
#[error("Error while loading quantizer proto struct from file: {0:?}")]
ProtoStorageError(#[from] crate::storage::protos::ProtoStorageError),
#[error("Error while converting proto struct to Scalar Qunatizer: {0:?}")]
QuantizerDecodeError(#[from] crate::storage::protos::ProtoConversionError),
}
/// Routes SQ errors into the crate-wide `ANNError` via its logging
/// constructor; `#[cold]` keeps this off the hot path.
impl From<SQError> for ANNError {
    #[cold]
    fn from(err: SQError) -> Self {
        ANNError::log_sq_error(err)
    }
}
#[cfg(test)]
mod tests {
    use crate::storage::VirtualStorageProvider;
    use diskann::utils::ONE;
    use diskann_quantization::scalar::train::ScalarQuantizationParameters;
    use diskann_utils::views::MatrixView;
    use diskann_vector::distance::Metric;
    use rstest::rstest;

    use super::*;

    // 1-bit quantization over 4-dimensional vectors; 5 training points.
    const NBITS: usize = 1;
    const DIM: usize = 4;
    const NPTS: usize = 5;
    // Training data: NPTS rows of DIM values, row-major.
    const DATA: [f32; 20] = [
        0.286541, -0.079761, 0.373634, 0.878595, -0.131049, -0.131040, 0.883841, 0.429512,
        -0.482576, 0.557701, -0.476350, -0.478727, 0.091383, -0.722600, -0.651460, -0.212363,
        -0.510018, 0.158241, -0.457242, -0.711176,
    ];
    // First row of DATA, reused as a test vector.
    const V: [f32; DIM] = [DATA[0], DATA[1], DATA[2], DATA[3]];

    /// Trains a quantizer on DATA and builds a 5-slot store for `metric`.
    fn make_store(metric: Metric) -> SQStore<NBITS> {
        let quantizer = ScalarQuantizationParameters::default()
            .train(MatrixView::try_from(&DATA, NPTS, DIM).unwrap());
        SQStore::new(quantizer, 5, metric, None)
    }

    #[test]
    fn test_dim() {
        let store = make_store(Metric::L2);
        assert_eq!(store.dim(), DIM);
    }

    #[test]
    fn test_set_and_get_vector() {
        let store = make_store(Metric::L2);
        store.set_vector(0, &V).unwrap();
        store.get_vector(0).unwrap();
    }

    // Dimension mismatch trips the debug_assert in set_vector.
    #[test]
    #[should_panic]
    fn test_set_vector_wrong_dim_panic_in_debug() {
        let store = make_store(Metric::L2);
        let _: Result<_, SQError> = store.set_vector(0, &[1.0f32; DIM + 1]);
    }

    // Out-of-bounds reads panic rather than return an error.
    #[test]
    #[should_panic]
    fn test_get_vector_oob() {
        let store = make_store(Metric::L2);
        let _: Result<_, SQError> = store.get_vector(NPTS);
    }

    #[test]
    fn test_prefetch_hint_ok() {
        let store = make_store(Metric::L2);
        store.prefetch_hint(NPTS - 1);
    }

    // Prefetching an out-of-bounds slot panics too.
    #[test]
    #[should_panic]
    fn test_prefetch_hint_oob() {
        let store = make_store(Metric::L2);
        store.prefetch_hint(NPTS);
    }

    /// Each supported metric yields its matching variant; Cosine is rejected.
    #[test]
    fn test_distance_computer_variants() {
        let dc_l2 = make_store(Metric::L2).distance_computer().unwrap();
        match dc_l2 {
            DistanceComputer::SquaredL2(_) => {}
            _ => panic!("expected SquaredL2 variant"),
        }
        let dc_ip = make_store(Metric::InnerProduct)
            .distance_computer()
            .unwrap();
        match dc_ip {
            DistanceComputer::InnerProduct(_) => {}
            _ => panic!("expected InnerProduct variant"),
        }
        let dc_cosine_normalized = make_store(Metric::CosineNormalized)
            .distance_computer()
            .unwrap();
        match dc_cosine_normalized {
            DistanceComputer::CosineNormalized(_) => {}
            _ => panic!("expected CosineNormalized variant"),
        }
        let dc_unsupported = make_store(Metric::Cosine).distance_computer().unwrap_err();
        match dc_unsupported {
            SQError::UnsupportedDistanceMetric(Metric::Cosine) => {}
            _ => panic!("expected UnsupportedDistanceMetric error"),
        }
    }

    /// query_computer succeeds for every supported metric/rescale combination.
    #[rstest]
    fn test_query_computer(
        #[values(Metric::L2, Metric::InnerProduct, Metric::CosineNormalized)] metric: Metric,
        #[values(false, true)] allow_rescale: bool,
    ) {
        let store = make_store(metric);
        let q = [1.0_f32; DIM];
        let result = store.query_computer(&q, allow_rescale);
        assert!(
            result.is_ok(),
            "query_computer() failed for metric {:?} with allow_rescale={}",
            metric,
            allow_rescale
        );
    }

    /// Raw quantized bytes round-trip through set_quant_vector unchanged.
    #[test]
    fn test_set_quant_vector() {
        let store = make_store(Metric::L2);
        let compressed_vec_len = CVRef::<NBITS>::canonical_bytes(DIM);
        let raw = vec![1u8; compressed_vec_len];
        unsafe {
            store.set_quant_vector(0, &raw).unwrap();
        }
        let slice = unsafe { store.data.get_slice(0) };
        assert_eq!(slice, raw.as_slice());
    }

    // Wrong record length trips the debug_assert in set_quant_vector.
    #[test]
    #[should_panic]
    fn test_set_quant_vector_with_wrong_dim_panics() {
        let store = make_store(Metric::L2);
        let wrong_compressed_vec_len = CVRef::<NBITS>::canonical_bytes(DIM) + 1;
        let raw = vec![1u8; wrong_compressed_vec_len];
        unsafe {
            store.set_quant_vector(0, &raw).unwrap();
        }
    }

    /// Distance evaluation works end-to-end for every supported metric.
    #[rstest]
    fn test_distance_computer_cosine_normalized(
        #[values(Metric::L2, Metric::InnerProduct, Metric::CosineNormalized)] metric: Metric,
    ) {
        let store = make_store(metric);
        let v1 = [0.1, 0.2, 0.3, 0.4];
        let v2 = [0.4, 0.3, 0.2, 0.1];
        store.set_vector(0, &v1).unwrap();
        store.set_vector(1, &v2).unwrap();
        let dc = store.distance_computer().unwrap();
        let x = store.get_vector(0).unwrap();
        let y = store.get_vector(1).unwrap();
        let _ = dc.evaluate_similarity(x, y);
    }

    /// Save/load round-trip preserves dimension and every stored record.
    #[tokio::test]
    async fn test_save_with_and_load_with() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let store = make_store(Metric::InnerProduct);
        let prefix = "/test";
        let metadata = AsyncIndexMetadata::new(prefix.to_string());
        let bytes_written = store.save_with(&storage_provider, &metadata).await.unwrap();
        let sq_storage = storage::SQStorage::new(prefix);
        assert!(bytes_written > 0);
        assert!(storage_provider.exists(sq_storage.compressed_data_path()),);
        assert!(storage_provider.exists(sq_storage.quantizer_path()));
        let ctx = AsyncQuantLoadContext {
            metadata,
            num_frozen_points: ONE,
            metric: Metric::InnerProduct,
            prefetch_lookahead: None,
            is_disk_index: false,
            prefetch_cache_line_level: None,
        };
        let loaded = SQStore::<NBITS>::load_with(&storage_provider, &ctx)
            .await
            .unwrap();
        assert_eq!(loaded.dim(), store.dim());
        // Byte-compare every slot of the original vs. reloaded store.
        for i in 0..NPTS {
            let original = unsafe { store.data.get_slice(i) };
            let loaded = unsafe { loaded.data.get_slice(i) };
            assert_eq!(original, loaded);
        }
    }
}