use std::{future::Future, sync::Mutex};
use diskann::{
ANNError, ANNErrorKind, ANNResult, default_post_processor,
error::IntoANNResult,
graph::{
AdjacencyList,
glue::{
self, DefaultPostProcessor, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy,
SearchStrategy,
},
workingset,
},
provider::{ExecutionContext, HasId},
utils::{IntoUsize, VectorRepr},
};
use diskann_quantization::{
alloc::{GlobalAllocator, ScopedAllocator},
meta::NotCanonical,
spherical,
};
use diskann_utils::future::AsyncFriendly;
use diskann_vector::{PreprocessedDistanceFunction, distance::Metric};
use thiserror::Error;
use super::{GetFullPrecision, Rerank};
use crate::{
common::IgnoreLockPoison,
model::graph::provider::async_::{
FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync,
common::{
AlignedMemoryVectorStore, CreateVectorStore, NoStore, SetElementHelper, TestCallCount,
VectorStore,
},
distances::UnwrapErr,
inmem::{DefaultProvider, FullPrecisionProvider, FullPrecisionStore},
postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy},
},
utils::{Bridge, BridgeErr},
};
impl From<Bridge<QueryComputerError>> for ANNError {
#[track_caller]
fn from(err: Bridge<QueryComputerError>) -> Self {
ANNError::new(ANNErrorKind::SQError, err)
}
}
impl From<Bridge<diskann_quantization::spherical::CompressionError>> for ANNError {
#[track_caller]
fn from(err: Bridge<diskann_quantization::spherical::CompressionError>) -> Self {
ANNError::new(ANNErrorKind::SQError, err)
}
}
impl From<Bridge<diskann_quantization::spherical::iface::QueryDistanceError>> for ANNError {
#[track_caller]
fn from(err: Bridge<diskann_quantization::spherical::iface::QueryDistanceError>) -> Self {
ANNError::new(ANNErrorKind::SQError, err)
}
}
impl From<Bridge<spherical::UnsupportedMetric>> for ANNError {
#[track_caller]
fn from(err: Bridge<spherical::UnsupportedMetric>) -> Self {
ANNError::new(ANNErrorKind::SQError, err)
}
}
#[derive(Debug, Clone, Copy, Error)]
#[error(transparent)]
pub struct AllocatorError(#[from] diskann_quantization::alloc::AllocatorError);
impl From<AllocatorError> for ANNError {
#[track_caller]
fn from(err: AllocatorError) -> Self {
ANNError::new(ANNErrorKind::SQError, err)
}
}
#[derive(Debug, Error)]
pub enum QueryComputerError {
#[error("Quantizer computer error : {0}")]
QuantizerComputerError(#[from] spherical::iface::QueryComputerError),
#[error("Error in converting query to full precision : {0}")]
FullPrecisionConversionErr(#[from] Box<dyn std::error::Error + Send + Sync>),
}
const WRITE_LOCK_GRANULARITY: usize = 16;
const PREFETCH_DEFAULT: usize = 8;
pub struct SphericalStore {
data: AlignedMemoryVectorStore<u8>,
plan: Box<dyn spherical::iface::Quantizer + Send + Sync>,
write_locks: Vec<Mutex<()>>,
prefetch_lookahead: usize,
num_get_calls: TestCallCount,
}
impl SphericalStore {
pub(super) fn new<P>(plan: P, num_vectors: usize, prefetch_lookahead: Option<usize>) -> Self
where
P: spherical::iface::Quantizer + AsyncFriendly,
{
let write_locks = (0..num_vectors.div_ceil(WRITE_LOCK_GRANULARITY))
.map(|_| Mutex::new(()))
.collect::<Vec<_>>();
let bytes = plan.bytes();
Self {
data: AlignedMemoryVectorStore::with_capacity(num_vectors, bytes),
plan: Box::new(plan),
write_locks,
prefetch_lookahead: prefetch_lookahead.unwrap_or(PREFETCH_DEFAULT),
num_get_calls: TestCallCount::default(),
}
}
pub(crate) fn prefetch_hint(&self, i: usize) {
let data = unsafe { self.data.get_slice(i) };
diskann_vector::prefetch_hint_max::<4, _>(data);
}
#[cfg(test)]
pub(super) fn input_dim(&self) -> usize {
self.plan.dim()
}
pub fn bytes(&self) -> usize {
self.plan.bytes()
}
pub fn output_dim(&self) -> usize {
self.plan.dim()
}
pub(super) fn get_vector(&self, i: usize) -> Result<spherical::iface::Opaque<'_>, RQError> {
self.num_get_calls.increment();
let data = unsafe { self.data.get_slice(i) };
Ok(spherical::iface::Opaque::new(data))
}
pub(super) fn set_vector<T>(&self, i: usize, v: &[T]) -> Result<(), RQError>
where
T: VectorRepr,
{
let _guard = self.write_locks[i / WRITE_LOCK_GRANULARITY].lock_or_panic();
let data_mut = unsafe { self.data.get_mut_slice(i) };
let vf32 = T::as_f32(v).map_err(|e| RQError::FullPrecisionConversionErr(Box::new(e)))?;
self.plan.compress(
&vf32,
spherical::iface::OpaqueMut::new(data_mut),
ScopedAllocator::global(),
)?;
Ok(())
}
pub(super) fn distance_computer(
&self,
) -> Result<spherical::iface::DistanceComputer, AllocatorError> {
Ok(self.plan.distance_computer(GlobalAllocator)?)
}
pub(super) fn query_computer<T>(
&self,
query: &[T],
layout: spherical::iface::QueryLayout,
allow_rescale: bool,
) -> Result<spherical::iface::QueryComputer, QueryComputerError>
where
T: VectorRepr,
{
let qf32 = T::as_f32(query)
.map_err(|e| QueryComputerError::FullPrecisionConversionErr(Box::new(e)))?;
let computer = self.plan.fused_query_computer(
qf32.as_ref(),
layout,
allow_rescale,
GlobalAllocator,
ScopedAllocator::global(),
)?;
Ok(computer)
}
pub fn prefetch_lookahead(&self) -> usize {
self.prefetch_lookahead
}
}
impl VectorStore for SphericalStore {
fn total(&self) -> usize {
self.data.max_vectors()
}
fn count_for_get_vector(&self) -> usize {
self.num_get_calls.get()
}
}
macro_rules! create_vector_store {
($N:literal) => {
impl CreateVectorStore for spherical::iface::Impl<$N> {
type Target = SphericalStore;
fn create(
self,
max_points: usize,
metric: Metric,
prefetch_lookahead: Option<usize>,
) -> Self::Target {
assert_eq!(self.quantizer().metric(), metric, "mismatched metrics!");
SphericalStore::new(self, max_points, prefetch_lookahead)
}
}
};
($N:literal, $($Ns:literal),*) => {
create_vector_store!($N);
$(create_vector_store!($Ns);)*
};
}
create_vector_store!(1, 2, 4);
impl<T> SetElementHelper<T> for SphericalStore
where
T: VectorRepr,
{
fn set_element(&self, id: &u32, element: &[T]) -> ANNResult<()> {
self.set_vector(id.into_usize(), element)?;
Ok(())
}
}
type Distance = UnwrapErr<spherical::iface::DistanceComputer, spherical::iface::DistanceError>;
pub struct PruneAccessor<'a> {
store: &'a SphericalStore,
neighbors: &'a SimpleNeighborProviderAsync,
distance: Distance,
}
impl HasId for PruneAccessor<'_> {
type Id = u32;
}
impl glue::PruneAccessor for PruneAccessor<'_> {
type ElementRef<'a> = spherical::iface::Opaque<'a>;
type View<'a>
= &'a Self
where
Self: 'a;
type Distance<'a>
= &'a Distance
where
Self: 'a;
type Neighbors<'a>
= &'a SimpleNeighborProviderAsync
where
Self: 'a;
async fn fill<Itr>(&mut self, _itr: Itr) -> ANNResult<(Self::View<'_>, Self::Distance<'_>)>
where
Itr: ExactSizeIterator<Item = Self::Id> + Clone + Send + Sync,
{
Ok((self, &self.distance))
}
fn neighbors(&mut self) -> Self::Neighbors<'_> {
self.neighbors
}
}
impl workingset::View<u32> for &PruneAccessor<'_> {
type ElementRef<'a> = spherical::iface::Opaque<'a>;
type Element<'a>
= spherical::iface::Opaque<'a>
where
Self: 'a;
fn get(&self, id: u32) -> Option<Self::Element<'_>> {
self.store.get_vector(id.into_usize()).ok()
}
}
pub struct QuantAccessor<'a, V, D, Ctx> {
provider: &'a DefaultProvider<V, SphericalStore, D, Ctx>,
computer: spherical::iface::QueryComputer,
id_buffer: AdjacencyList<u32>,
}
impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx>
where
V: AsyncFriendly,
D: AsyncFriendly,
Ctx: ExecutionContext,
{
pub(crate) fn new(
provider: &'a DefaultProvider<V, SphericalStore, D, Ctx>,
query: &[f32],
layout: spherical::iface::QueryLayout,
is_search: bool,
) -> ANNResult<Self> {
let computer = provider
.aux_vectors
.query_computer(query, layout, is_search)
.bridge_err()?;
Ok(Self {
provider,
computer,
id_buffer: AdjacencyList::with_capacity(32),
})
}
#[cfg(test)]
pub(crate) fn computer(&self) -> &spherical::iface::QueryComputer {
&self.computer
}
}
impl<T, D, Ctx> GetFullPrecision for QuantAccessor<'_, FullPrecisionStore<T>, D, Ctx>
where
T: VectorRepr,
{
type Repr = T;
fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync<T> {
&self.provider.base_vectors
}
}
impl<V, D, Ctx> HasId for QuantAccessor<'_, V, D, Ctx> {
type Id = u32;
}
impl<V, D, Ctx> glue::SearchAccessor for QuantAccessor<'_, V, D, Ctx>
where
V: AsyncFriendly,
D: AsyncFriendly,
Ctx: ExecutionContext,
{
fn starting_points(&self) -> impl Future<Output = ANNResult<Vec<u32>>> {
std::future::ready(self.provider.starting_points())
}
fn num_starting_points(&self) -> impl Future<Output = ANNResult<usize>> {
std::future::ready(Ok(self.provider.num_start_points()))
}
fn start_point_distances<F>(
&mut self,
mut f: F,
) -> impl std::future::Future<Output = ANNResult<()>> + Send
where
F: FnMut(Self::Id, f32) + Send,
{
let mut f = move || -> ANNResult<()> {
for i in self.provider.starting_points()? {
let distance = self
.computer
.evaluate_similarity(self.provider.aux_vectors.get_vector(i.into_usize())?)
.bridge_err()?;
f(i, distance);
}
Ok(())
};
std::future::ready(f())
}
fn expand_beam<Itr, P, F>(
&mut self,
ids: Itr,
mut pred: P,
mut on_neighbors: F,
) -> impl std::future::Future<Output = ANNResult<()>> + Send
where
Itr: Iterator<Item = Self::Id> + Send,
P: glue::HybridPredicate<Self::Id> + Send + Sync,
F: FnMut(Self::Id, f32) + Send,
{
let f = move || -> ANNResult<()> {
let id_buffer = &mut self.id_buffer;
for n in ids {
self.provider
.neighbor_provider
.get_neighbors_sync(n.into_usize(), id_buffer)?;
id_buffer.retain(|i| pred.eval_mut(i));
let len = id_buffer.len();
let lookahead = self.provider.aux_vectors.prefetch_lookahead();
for id in id_buffer.iter().take(lookahead) {
self.provider.aux_vectors.prefetch_hint(id.into_usize());
}
for (i, id) in id_buffer.iter().enumerate() {
if lookahead > 0 && i + lookahead < len {
self.provider
.aux_vectors
.prefetch_hint(id_buffer[i + lookahead].into_usize());
}
let vector = self.provider.aux_vectors.get_vector(id.into_usize())?;
let distance = self.computer.evaluate_similarity(vector).bridge_err()?;
on_neighbors(*id, distance);
}
}
Ok(())
};
std::future::ready(f())
}
}
impl<V, D, Ctx> AsDeletionCheck for QuantAccessor<'_, V, D, Ctx>
where
V: AsyncFriendly,
D: AsyncFriendly + DeletionCheck,
Ctx: ExecutionContext,
{
type Checker = D;
fn as_deletion_check(&self) -> &D {
&self.provider.deleted
}
}
#[derive(Debug, Clone, Copy)]
pub struct Quantized {
layout: spherical::iface::QueryLayout,
is_search: bool,
}
impl Quantized {
pub fn build() -> Self {
Self {
layout: spherical::iface::QueryLayout::SameAsData,
is_search: false,
}
}
pub fn search(layout: spherical::iface::QueryLayout) -> Self {
Self {
layout,
is_search: true,
}
}
}
impl<'a, D, Ctx, T> SearchStrategy<'a, FullPrecisionProvider<T, SphericalStore, D, Ctx>, &'a [T]>
for Quantized
where
T: VectorRepr,
D: AsyncFriendly + DeletionCheck,
Ctx: ExecutionContext,
{
type SearchAccessor = QuantAccessor<'a, FullPrecisionStore<T>, D, Ctx>;
type SearchAccessorError = ANNError;
fn search_accessor(
&'a self,
provider: &'a FullPrecisionProvider<T, SphericalStore, D, Ctx>,
_context: &'a Ctx,
query: &'a [T],
) -> Result<Self::SearchAccessor, Self::SearchAccessorError> {
let as_f32 = T::as_f32(query).into_ann_result()?;
QuantAccessor::new(provider, &as_f32, self.layout, self.is_search)
}
}
impl<'a, D, Ctx, T>
DefaultPostProcessor<'a, FullPrecisionProvider<T, SphericalStore, D, Ctx>, &'a [T]>
for Quantized
where
T: VectorRepr,
D: AsyncFriendly + DeletionCheck,
Ctx: ExecutionContext,
{
default_post_processor!(Pipeline<FilterStartPoints, Rerank>);
}
impl<'a, D, Ctx, T> SearchStrategy<'a, DefaultProvider<NoStore, SphericalStore, D, Ctx>, &'a [T]>
for Quantized
where
T: VectorRepr,
D: AsyncFriendly + DeletionCheck,
Ctx: ExecutionContext,
{
type SearchAccessor = QuantAccessor<'a, NoStore, D, Ctx>;
type SearchAccessorError = ANNError;
fn search_accessor(
&'a self,
provider: &'a DefaultProvider<NoStore, SphericalStore, D, Ctx>,
_context: &'a Ctx,
query: &[T],
) -> Result<Self::SearchAccessor, Self::SearchAccessorError> {
let as_f32 = T::as_f32(query).into_ann_result()?;
QuantAccessor::new(provider, &as_f32, self.layout, self.is_search)
}
}
impl<'a, D, Ctx, T>
DefaultPostProcessor<'a, DefaultProvider<NoStore, SphericalStore, D, Ctx>, &'a [T]>
for Quantized
where
T: VectorRepr,
D: AsyncFriendly + DeletionCheck,
Ctx: ExecutionContext,
{
default_post_processor!(Pipeline<FilterStartPoints, RemoveDeletedIdsAndCopy>);
}
impl<V, D, Ctx> PruneStrategy<DefaultProvider<V, SphericalStore, D, Ctx>> for Quantized
where
V: AsyncFriendly,
D: AsyncFriendly + DeletionCheck,
Ctx: ExecutionContext,
{
type PruneAccessor<'a> = PruneAccessor<'a>;
type PruneAccessorError = AllocatorError;
fn prune_accessor<'a>(
&'a self,
provider: &'a DefaultProvider<V, SphericalStore, D, Ctx>,
_context: &'a Ctx,
_capacity: usize,
) -> Result<Self::PruneAccessor<'a>, Self::PruneAccessorError> {
let accessor = PruneAccessor {
store: &provider.aux_vectors,
neighbors: provider.neighbors(),
distance: provider
.aux_vectors
.distance_computer()
.map(UnwrapErr::new)?,
};
Ok(accessor)
}
}
impl<'a, V, D, Ctx, T> InsertStrategy<'a, DefaultProvider<V, SphericalStore, D, Ctx>, &'a [T]>
for Quantized
where
V: AsyncFriendly,
D: AsyncFriendly + DeletionCheck,
Ctx: ExecutionContext,
Quantized: SearchStrategy<'a, DefaultProvider<V, SphericalStore, D, Ctx>, &'a [T]>,
{
type PruneStrategy = Self;
fn prune_strategy(&self) -> Self::PruneStrategy {
*self
}
}
impl<V, D, Ctx, B> glue::MultiInsertStrategy<DefaultProvider<V, SphericalStore, D, Ctx>, B>
for Quantized
where
V: AsyncFriendly,
D: AsyncFriendly + DeletionCheck,
Ctx: ExecutionContext,
B: glue::Batch,
Self: for<'a> InsertStrategy<
'a,
DefaultProvider<V, SphericalStore, D, Ctx>,
B::Element<'a>,
PruneStrategy = Self,
>,
{
type Seed = ();
type FinishError = diskann::error::Infallible;
type PruneStrategy = Self;
type InsertStrategy = Self;
fn insert_strategy(&self) -> Self::InsertStrategy {
*self
}
fn finish<Itr>(
&self,
_provider: &DefaultProvider<V, SphericalStore, D, Ctx>,
_ctx: &Ctx,
_batch: &std::sync::Arc<B>,
_ids: Itr,
) -> impl std::future::Future<Output = Result<Self::Seed, Self::FinishError>> + Send
where
Itr: ExactSizeIterator<Item = u32> + Send,
{
std::future::ready(Ok(()))
}
fn seeded_prune_accessor<'a>(
&'a self,
provider: &'a DefaultProvider<V, SphericalStore, D, Ctx>,
context: &'a Ctx,
_seed: &'a (),
capacity: usize,
) -> ANNResult<PruneAccessor<'a>> {
Ok(self.prune_accessor(provider, context, capacity)?)
}
}
#[derive(Debug, Error)]
pub enum RQError {
#[error("Issue with canonical layout of data: {0:?}")]
CanonicalLayoutError(#[from] NotCanonical),
#[error("error during data compression")]
CompressionError(#[from] spherical::iface::CompressionError),
#[error("Error during full-precision conversion: {0}")]
FullPrecisionConversionErr(Box<dyn std::error::Error + Send + Sync>),
}
impl From<RQError> for ANNError {
#[cold]
#[track_caller]
fn from(err: RQError) -> Self {
ANNError::log_sq_error(err)
}
}
#[cfg(test)]
mod tests {
use diskann_quantization::{
alloc::GlobalAllocator,
spherical::{SphericalQuantizer, SupportedMetric},
};
use diskann_utils::views::{Matrix, MatrixView};
use diskann_vector::{
DistanceFunction, PreprocessedDistanceFunction, PureDistanceFunction,
distance::{InnerProduct, Metric, SquaredL2},
};
use rand::{SeedableRng, distr::Distribution, rngs::StdRng};
use rand_distr::StandardNormal;
use super::*;
fn make_store<const NBITS: usize>(
data: MatrixView<f32>,
metric: SupportedMetric,
rng: &mut StdRng,
) -> SphericalStore
where
spherical::iface::Impl<NBITS>:
spherical::iface::Quantizer + spherical::iface::Constructible,
{
let quantizer = SphericalQuantizer::train(
data,
diskann_quantization::algorithms::transforms::TransformKind::PaddingHadamard {
target_dim: diskann_quantization::algorithms::transforms::TargetDim::Natural,
},
metric,
diskann_quantization::spherical::PreScale::None,
rng,
GlobalAllocator,
)
.unwrap();
SphericalStore::new(
spherical::iface::Impl::<NBITS>::new(quantizer).unwrap(),
data.nrows(),
None,
)
}
fn dataset(nrows: usize, ncols: usize, rng: &mut StdRng) -> Matrix<f32> {
Matrix::new(
diskann_utils::views::Init(|| StandardNormal {}.sample(rng)),
nrows,
ncols,
)
}
#[test]
fn test_dim() {
let mut rng = StdRng::seed_from_u64(0x721e3de995bc908c);
let data = test_dataset();
let store = make_store::<1>(data.as_view(), Metric::L2.try_into().unwrap(), &mut rng);
assert_eq!(store.input_dim(), data.ncols());
}
#[test]
fn test_set_and_get_vector() {
let input_dim = 30;
let mut rng = StdRng::seed_from_u64(0x721e3de995bc908c);
let data = dataset(5, input_dim, &mut rng);
let store = make_store::<1>(data.as_view(), SupportedMetric::SquaredL2, &mut rng);
{
let v = store.get_vector(0).unwrap().into_inner();
assert_eq!(v.len(), store.plan.bytes());
for i in v.iter() {
assert_eq!(*i, 0)
}
}
store.set_vector(0, data.row(0)).unwrap();
{
let v = store.get_vector(0).unwrap().into_inner();
assert_eq!(v.len(), store.plan.bytes());
let mut nonzero_count = 0;
for i in v.iter() {
if *i != 0 {
nonzero_count += 1;
}
}
assert_eq!(
nonzero_count,
9 * v.len() / 10,
"Expected roughly 90% of the vectors to be nonzero. Instead, on {} of {} wer.",
nonzero_count,
v.len(),
);
}
}
#[test]
#[should_panic]
fn test_get_vector_oob() {
let input_dim = 5;
let mut rng = StdRng::seed_from_u64(0x5460899c60762fed);
let data = dataset(5, input_dim, &mut rng);
let store = make_store::<1>(data.as_view(), SupportedMetric::SquaredL2, &mut rng);
let _ = store.get_vector(data.nrows());
}
#[test]
fn test_prefetch_hint_ok() {
let input_dim = 5;
let mut rng = StdRng::seed_from_u64(0x40f06c4034892796);
let data = dataset(5, input_dim, &mut rng);
let store = make_store::<1>(data.as_view(), SupportedMetric::InnerProduct, &mut rng);
for i in 0..data.nrows() {
store.prefetch_hint(i)
}
}
#[test]
#[should_panic]
fn test_prefetch_hint_oob() {
let input_dim = 5;
let mut rng = StdRng::seed_from_u64(0x5abccaa184dfdb93);
let data = dataset(5, input_dim, &mut rng);
let store = make_store::<1>(data.as_view(), SupportedMetric::InnerProduct, &mut rng);
store.prefetch_hint(data.nrows());
}
fn relative_error(got: f32, expected: f32) -> f32 {
if expected == 0.0 {
return if got == 0.0 { 0.0 } else { f32::INFINITY };
}
(got - expected).abs() / expected.abs()
}
#[test]
fn test_distance_computer_variants() {
let input_dim = 128;
let mut rng = StdRng::seed_from_u64(0x5abccaa184dfdb93);
let data = dataset(10, input_dim, &mut rng);
{
let store = make_store::<1>(data.as_view(), SupportedMetric::SquaredL2, &mut rng);
let computer = store.distance_computer().unwrap();
let max_relative_error = 0.25;
for (i, r) in data.row_iter().enumerate() {
store.set_vector(i, r).unwrap();
}
for (i, a) in data.row_iter().enumerate() {
for (j, b) in data.row_iter().enumerate().skip(i + 1) {
let expected: f32 = SquaredL2::evaluate(a, b);
let got: f32 = computer
.evaluate_similarity(
store.get_vector(i).unwrap(),
store.get_vector(j).unwrap(),
)
.unwrap();
let err = relative_error(got, expected);
assert!(
err < max_relative_error,
"expected a relative error less than {}, instead found {} for\
expected = {}, got = {}",
max_relative_error,
err,
expected,
got
);
}
}
}
{
let store = make_store::<1>(data.as_view(), SupportedMetric::InnerProduct, &mut rng);
let computer = store.distance_computer().unwrap();
for (i, r) in data.row_iter().enumerate() {
store.set_vector(i, r).unwrap();
}
let mut signs_match = 0;
let mut total = 0;
for (i, a) in data.row_iter().enumerate() {
for (j, b) in data.row_iter().enumerate().skip(i + 1) {
total += 1;
let expected: f32 = InnerProduct::evaluate(a, b);
let got: f32 = computer
.evaluate_similarity(
store.get_vector(i).unwrap(),
store.get_vector(j).unwrap(),
)
.unwrap();
if expected.is_sign_negative() == got.is_sign_negative() {
signs_match += 1;
}
}
}
assert!(
signs_match > 8 * total / 10,
"expected 80% of the inner-product signs to match. Instead got {} of {}.",
signs_match,
total,
);
}
}
#[test]
fn test_query_computer_variants() {
let input_dim = 128;
let mut rng = StdRng::seed_from_u64(0x5abccaa184dfdb93);
let data = dataset(10, input_dim, &mut rng);
{
let store = make_store::<1>(data.as_view(), SupportedMetric::SquaredL2, &mut rng);
let max_relative_error = 0.2;
for (i, r) in data.row_iter().enumerate() {
store.set_vector(i, r).unwrap();
}
for (i, a) in data.row_iter().enumerate() {
let computer = store
.query_computer(a, spherical::iface::QueryLayout::FourBitTransposed, false)
.unwrap();
for (j, b) in data.row_iter().enumerate() {
if i == j {
continue;
}
let expected: f32 = SquaredL2::evaluate(a, b);
let got: f32 = computer
.evaluate_similarity(store.get_vector(j).unwrap())
.unwrap();
let err = relative_error(got, expected);
assert!(
err < max_relative_error,
"expected a relative error less than {}, instead found {}.\
expected = {}, got = {} for pair (i, j) = ({}, {})",
max_relative_error,
err,
expected,
got,
i,
j,
);
}
}
}
{
let store = make_store::<1>(data.as_view(), SupportedMetric::InnerProduct, &mut rng);
for (i, r) in data.row_iter().enumerate() {
store.set_vector(i, r).unwrap();
}
let mut signs_match = 0;
let mut total = 0;
for (i, a) in data.row_iter().enumerate() {
let computer = store
.query_computer(a, spherical::iface::QueryLayout::FourBitTransposed, true)
.unwrap();
for (j, b) in data.row_iter().enumerate() {
if i == j {
continue;
}
total += 1;
let expected: f32 = InnerProduct::evaluate(a, b);
let got: f32 = computer
.evaluate_similarity(store.get_vector(j).unwrap())
.unwrap();
if expected.is_sign_negative() == got.is_sign_negative() {
signs_match += 1;
}
}
}
assert!(
signs_match > 85 * total / 100,
"expected 85% of the inner-product signs to match. Instead got {} of {}.",
signs_match,
total,
);
}
}
#[test]
fn test_compression_errors() {
let input_dim = 10;
let mut rng = StdRng::seed_from_u64(0x721e3de995bc908c);
let data = dataset(5, input_dim, &mut rng);
let store = make_store::<1>(data.as_view(), SupportedMetric::SquaredL2, &mut rng);
let err = store
.set_vector(0, &(data.row(0)[..input_dim - 10]))
.unwrap_err();
assert!(matches!(err, RQError::CompressionError(..)));
}
fn test_dataset() -> Matrix<f32> {
let data = vec![
0.28657,
-0.0318168,
0.0666847,
0.0329265,
-0.00829283,
0.168735,
-0.000846311,
-0.360779, -0.0968938,
0.161921,
-0.0979579,
0.102228,
-0.259928,
-0.139634,
0.165384,
-0.293443, 0.130205,
0.265737,
0.401816,
-0.407552,
0.13012,
-0.0475244,
0.511723,
-0.4372, -0.0979126,
0.135861,
-0.0154144,
-0.14047,
-0.0250029,
-0.190279,
0.407283,
-0.389184, -0.264153,
0.0696822,
-0.145585,
0.370284,
0.186825,
-0.140736,
0.274703,
-0.334563, 0.247613,
0.513165,
-0.0845867,
0.0532264,
-0.00480601,
-0.122408,
0.47227,
-0.268301, 0.103198,
0.30756,
-0.316293,
-0.0686877,
-0.330729,
-0.461997,
0.550857,
-0.240851, 0.128258,
0.786291,
-0.0268103,
0.111763,
-0.308962,
-0.17407,
0.437154,
-0.159879, 0.00374063,
0.490301,
0.0327826,
-0.0340962,
-0.118605,
0.163879,
0.2737,
-0.299942, -0.284077,
0.249377,
-0.0307734,
-0.0661631,
0.233854,
0.427987,
0.614132,
-0.288649, -0.109492,
0.203939,
-0.73956,
-0.130748,
0.22072,
0.0647836,
0.328726,
-0.374602, -0.223114,
0.0243489,
0.109195,
-0.416914,
0.0201052,
-0.0190542,
0.947078,
-0.333229, -0.165869,
-0.00296729,
-0.414378,
0.231321,
0.205365,
0.161761,
0.148608,
-0.395063, -0.0498255,
0.193279,
-0.110946,
-0.181174,
-0.274578,
-0.227511,
0.190208,
-0.256174, -0.188106,
-0.0292958,
0.0930939,
0.0558456,
0.257437,
0.685481,
0.307922,
-0.320006, 0.250035,
0.275942,
-0.0856306,
-0.352027,
-0.103509,
-0.00890859,
0.276121,
-0.324718, ];
Matrix::try_from(data.into(), 16, 8).unwrap()
}
}