use std::marker::PhantomData;
use diskann_utils::{Reborrow, ReborrowMut};
use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
use diskann_wide::{
Architecture,
arch::{Scalar, Target1, Target2},
};
#[cfg(feature = "flatbuffers")]
use flatbuffers::FlatBufferBuilder;
use thiserror::Error;
#[cfg(target_arch = "x86_64")]
use diskann_wide::arch::x86_64::{V3, V4};
#[cfg(target_arch = "aarch64")]
use diskann_wide::arch::aarch64::Neon;
use super::{
CompensatedCosine, CompensatedIP, CompensatedSquaredL2, Data, DataMut, DataRef, FullQuery,
FullQueryMut, FullQueryRef, Query, QueryMut, QueryRef, SphericalQuantizer, SupportedMetric,
quantizer,
};
use crate::{
AsFunctor, CompressIntoWith,
alloc::{
Allocator, AllocatorCore, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone,
},
bits::{self, Representation, Unsigned},
distances::{self, UnequalLengths},
error::InlineError,
meta,
num::PowerOfTwo,
poly,
};
#[cfg(feature = "flatbuffers")]
use crate::{alloc::CompoundError, flatbuffers as fb};
/// Shorthand for a distance evaluation result: `Ok(f32)` or a distance error.
type Rf32 = distances::Result<f32>;
/// Size and alignment requirements for a caller-provided query buffer.
#[derive(Debug, Clone)]
pub struct QueryBufferDescription {
    // Required buffer length in bytes.
    size: usize,
    // Required buffer alignment (always a power of two).
    align: PowerOfTwo,
}
impl QueryBufferDescription {
pub fn new(size: usize, align: PowerOfTwo) -> Self {
Self { size, align }
}
pub fn bytes(&self) -> usize {
self.size
}
pub fn align(&self) -> PowerOfTwo {
self.align
}
}
/// Object-safe facade over a bit-width-specialized quantizer.
///
/// `A` is the allocator used for any owned state the returned computers and
/// clones require.
pub trait Quantizer<A = GlobalAllocator>: Send + Sync
where
    A: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    /// Bits of precision per quantized dimension.
    fn nbits(&self) -> usize;
    /// Canonical byte length of one compressed data vector.
    fn bytes(&self) -> usize;
    /// Dimensionality of the quantized (output) space.
    fn dim(&self) -> usize;
    /// Dimensionality of the original full-precision (input) space.
    fn full_dim(&self) -> usize;
    /// Builds a data-vs-data distance computer owned by `allocator`.
    fn distance_computer(&self, allocator: A) -> Result<DistanceComputer<A>, AllocatorError>;
    /// Borrows the internally cached data-vs-data distance computer.
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer;
    /// Builds a distance computer whose query side is interpreted as `layout`.
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: A,
    ) -> Result<DistanceComputer<A>, DistanceComputerError>;
    /// Buffer size/alignment needed to hold a query compressed as `layout`.
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>;
    /// Compresses `x` into `buffer` using `layout`.
    ///
    /// When `allow_rescale` is set, implementations may rescale the query
    /// first (done for the inner-product metric in this file).
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError>;
    /// Compresses `x` and returns a computer with the query bound into it.
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: A,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<A>, QueryComputerError>;
    /// Whether `layout` is accepted by this quantizer's bit width.
    fn is_supported(&self, layout: QueryLayout) -> bool;
    /// Compresses `x` into `into` using the data layout.
    fn compress(
        &self,
        x: &[f32],
        into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;
    /// The distance metric this quantizer targets.
    fn metric(&self) -> SupportedMetric;
    /// Clones `self` into a type-erased handle owned by `allocator`.
    fn try_clone_into(&self, allocator: A) -> Result<Poly<dyn Quantizer<A>, A>, AllocatorError>;
    crate::utils::features! {
        #![feature = "flatbuffers"]
        // Serializes the quantizer to a flatbuffer byte blob (only when the
        // "flatbuffers" feature is enabled).
        fn serialize(&self, allocator: A) -> Result<Poly<[u8], A>, AllocatorError>;
    }
}
/// A query layout that the given bit width cannot serve.
#[derive(Debug, Error)]
#[error("Layout {layout} is not supported for {desc}")]
pub struct UnsupportedQueryLayout {
    // The rejected layout.
    layout: QueryLayout,
    // Static description of the quantizer that rejected it, e.g. "1-bit compression".
    desc: &'static str,
}

impl UnsupportedQueryLayout {
    fn new(layout: QueryLayout, desc: &'static str) -> Self {
        Self { layout, desc }
    }
}
/// Errors from building a layout-specific distance computer.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DistanceComputerError {
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}

/// Errors from compressing a query into a caller-provided buffer.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryCompressionError {
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    #[error(transparent)]
    NotCanonical(#[from] NotCanonical),
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}

/// Errors from building a fused (query-bound) computer.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryComputerError {
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
#[derive(Debug, Error)]
#[error("Error occured during query compression")]
pub enum CompressionError {
NotCanonical(#[source] InlineError<16>),
CompressionError(#[source] quantizer::CompressionError),
}
impl CompressionError {
fn not_canonical<E>(error: E) -> Self
where
E: std::error::Error + Send + Sync + 'static,
{
Self::NotCanonical(InlineError::new(error))
}
}
/// An opaque byte argument that could not be reinterpreted because its length
/// or alignment was not canonical for the requested layout.
#[derive(Debug, Error)]
#[error("An opaque argument did not have the required alignment or length")]
pub struct NotCanonical {
    // Boxed (heap-allocated) underlying cause, exposed via `Error::source`.
    source: Box<dyn std::error::Error + Send + Sync>,
}

impl NotCanonical {
    fn new<E>(err: E) -> Self
    where
        E: std::error::Error + Send + Sync + 'static,
    {
        Self {
            source: Box::new(err),
        }
    }
}
/// Read-only, type-erased byte view over a compressed vector or query.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
pub struct Opaque<'a>(&'a [u8]);

impl<'a> Opaque<'a> {
    /// Wraps `slice` as an opaque byte view.
    pub fn new(slice: &'a [u8]) -> Self {
        Opaque(slice)
    }

    /// Recovers the wrapped byte slice with its original lifetime.
    pub fn into_inner(self) -> &'a [u8] {
        let Opaque(bytes) = self;
        bytes
    }
}

// Allow slice operations (`len`, indexing, iteration) directly on the view.
impl std::ops::Deref for Opaque<'_> {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0
    }
}
// `Opaque` is `Copy`, so reborrowing for a shorter lifetime is a plain copy.
impl<'short> Reborrow<'short> for Opaque<'_> {
    type Target = Opaque<'short>;
    fn reborrow(&'short self) -> Self::Target {
        *self
    }
}
/// Mutable, type-erased byte view used as an output buffer for compression.
#[derive(Debug)]
#[repr(transparent)]
pub struct OpaqueMut<'a>(&'a mut [u8]);

impl<'a> OpaqueMut<'a> {
    /// Wraps `slice` as a writable opaque buffer.
    pub fn new(slice: &'a mut [u8]) -> Self {
        OpaqueMut(slice)
    }

    /// Grants mutable access to the wrapped bytes without consuming the view.
    pub fn inspect(&mut self) -> &mut [u8] {
        self.0
    }
}

// Read access via slice operations.
impl std::ops::Deref for OpaqueMut<'_> {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.0
    }
}

// Write access via slice operations.
impl std::ops::DerefMut for OpaqueMut<'_> {
    fn deref_mut(&mut self) -> &mut [u8] {
        self.0
    }
}
/// Memory layouts a query may take when handed to a distance computer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryLayout {
    /// Query compressed exactly like stored data points.
    SameAsData,
    /// Four-bit quantized query stored in bit-transposed order.
    FourBitTransposed,
    /// Densely packed scalar-quantized query.
    ScalarQuantized,
    /// Uncompressed `f32` query.
    FullPrecision,
}

impl QueryLayout {
    /// Every layout variant, for exhaustive coverage in tests.
    #[cfg(test)]
    fn all() -> [Self; 4] {
        [
            QueryLayout::SameAsData,
            QueryLayout::FourBitTransposed,
            QueryLayout::ScalarQuantized,
            QueryLayout::FullPrecision,
        ]
    }
}

// Human-readable form; identical to the `Debug` rendering.
impl std::fmt::Display for QueryLayout {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, fmt)
    }
}
/// Internal trait: lets a concrete query storage type report which
/// [`QueryLayout`] it corresponds to.
trait ReportQueryLayout {
    fn report_query_layout(&self) -> QueryLayout;
}
// A `Reify` wrapper reports whatever its inner functor reports.
impl<T, M, L, R> ReportQueryLayout for Reify<T, M, L, R>
where
    T: ReportQueryLayout,
{
    fn report_query_layout(&self) -> QueryLayout {
        self.inner.report_query_layout()
    }
}

// A curried functor reports the layout of its bound query operand.
impl<D, Q> ReportQueryLayout for Curried<D, Q>
where
    Q: ReportQueryLayout,
{
    fn report_query_layout(&self) -> QueryLayout {
        self.query.report_query_layout()
    }
}

// Owned data storage <-> `SameAsData`.
impl<const NBITS: usize, A> ReportQueryLayout for Data<NBITS, A>
where
    Unsigned: Representation<NBITS>,
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::SameAsData
    }
}

// Densely packed quantized query <-> `ScalarQuantized`.
impl<const NBITS: usize, A> ReportQueryLayout for Query<NBITS, bits::Dense, A>
where
    Unsigned: Representation<NBITS>,
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::ScalarQuantized
    }
}

// Bit-transposed 4-bit query <-> `FourBitTransposed`.
impl<A> ReportQueryLayout for Query<4, bits::BitTranspose, A>
where
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FourBitTransposed
    }
}

// Full-precision query <-> `FullPrecision`.
impl<A> ReportQueryLayout for FullQuery<A>
where
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FullPrecision
    }
}
/// Internal trait: reinterprets an opaque byte view as a typed, borrowed
/// query/data reference of dimensionality `dim`.
trait FromOpaque: 'static + Send + Sync {
    /// The borrowed, typed view produced from the bytes.
    type Target<'a>;
    /// Error returned when the bytes are not canonical for this layout.
    type Error: std::error::Error + Send + Sync + 'static;
    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error>;
}
/// Marker: reify opaque bytes as a full-precision (`f32`) query.
#[derive(Debug, Default)]
pub(super) struct AsFull;

/// Marker: reify opaque bytes with the same layout as stored data.
#[derive(Debug, Default)]
pub(super) struct AsData<const NBITS: usize>;

/// Marker: reify opaque bytes as a quantized query using permutation
/// strategy `Perm`.
#[derive(Debug)]
pub(super) struct AsQuery<const NBITS: usize, Perm = bits::Dense> {
    _marker: PhantomData<Perm>,
}

// Manual impl: `derive(Default)` would impose an unnecessary `Perm: Default`
// bound on the phantom parameter.
impl<const NBITS: usize, Perm> Default for AsQuery<NBITS, Perm> {
    fn default() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}
// Full-precision queries use front-aligned canonical form.
impl FromOpaque for AsFull {
    type Target<'a> = FullQueryRef<'a>;
    type Error = meta::slice::NotCanonical;
    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
        Self::Target::from_canonical(query.into_inner(), dim)
    }
}

impl ReportQueryLayout for AsFull {
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FullPrecision
    }
}

// Data-layout queries use back-aligned canonical form.
impl<const NBITS: usize> FromOpaque for AsData<NBITS>
where
    Unsigned: Representation<NBITS>,
{
    type Target<'a> = DataRef<'a, NBITS>;
    type Error = meta::NotCanonical;
    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
        Self::Target::from_canonical_back(query.into_inner(), dim)
    }
}

impl<const NBITS: usize> ReportQueryLayout for AsData<NBITS> {
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::SameAsData
    }
}

// Quantized queries (any permutation strategy) use back-aligned canonical form.
impl<const NBITS: usize, Perm> FromOpaque for AsQuery<NBITS, Perm>
where
    Unsigned: Representation<NBITS>,
    Perm: bits::PermutationStrategy<NBITS> + Send + Sync + 'static,
{
    type Target<'a> = QueryRef<'a, NBITS, Perm>;
    type Error = meta::NotCanonical;
    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
        Self::Target::from_canonical_back(query.into_inner(), dim)
    }
}

impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::Dense> {
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::ScalarQuantized
    }
}

impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::BitTranspose> {
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FourBitTransposed
    }
}
/// Binds a distance functor `inner` to a fixed dimensionality and a concrete
/// architecture `arch`, reifying opaque byte arguments as `L` (left-hand) and
/// `R` (right-hand) typed views before evaluation.
pub(super) struct Reify<T, M, L, R> {
    inner: T,
    // Dimensionality used to validate/reinterpret opaque arguments.
    dim: usize,
    arch: M,
    _markers: PhantomData<(L, R)>,
}

impl<T, M, L, R> Reify<T, M, L, R> {
    /// Wraps `inner` for vectors of `dim` dimensions on architecture `arch`.
    pub(super) fn new(inner: T, dim: usize, arch: M) -> Self {
        Self {
            inner,
            dim,
            arch,
            _markers: PhantomData,
        }
    }
}
// One-argument (query already bound) evaluation: `L = ()` marks the curried
// case. The opaque right-hand side is reified as `R` before dispatch.
impl<M, T, R> DynQueryComputer for Reify<T, M, (), R>
where
    M: Architecture,
    R: FromOpaque,
    T: ReportQueryLayout + Send + Sync,
    for<'a> &'a T: Target1<M, Rf32, R::Target<'a>>,
{
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
        // Outer `run2` enters the target architecture once; reification and
        // the inner `run1` happen inside that context.
        self.arch.run2(
            |this: &Self, x| {
                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| QueryDistanceError::XReify(InlineError::new(err)))?;
                this.arch
                    .run1(&this.inner, x)
                    .map_err(QueryDistanceError::UnequalLengths)
            },
            self,
            x,
        )
    }
    fn layout(&self) -> QueryLayout {
        self.inner.report_query_layout()
    }
}
// Two-argument evaluation: both opaque operands are reified (`Q` on the left,
// `R` on the right) before invoking the copyable functor `T`.
impl<T, M, Q, R> DynDistanceComputer for Reify<T, M, Q, R>
where
    M: Architecture,
    Q: FromOpaque + Default + ReportQueryLayout,
    R: FromOpaque,
    T: for<'a> Target2<M, Rf32, Q::Target<'a>, R::Target<'a>> + Copy + Send + Sync,
{
    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
        self.arch.run3(
            |this: &Self, query, x| {
                let query = Q::from_opaque(query, this.dim)
                    .map_err(|err| DistanceError::QueryReify(InlineError::<24>::new(err)))?;
                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| DistanceError::XReify(InlineError::<16>::new(err)))?;
                this.arch
                    .run2_inline(this.inner, query, x)
                    .map_err(DistanceError::UnequalLengths)
            },
            self,
            query,
            x,
        )
    }
    fn layout(&self) -> QueryLayout {
        // `Q` is a zero-sized marker, so `default()` is free.
        Q::default().report_query_layout()
    }
}
/// Failure while evaluating a fused (query-bound) distance.
#[derive(Debug, Error)]
pub enum QueryDistanceError {
    /// The opaque data bytes could not be reinterpreted in the expected layout.
    #[error("trouble trying to reify the argument")]
    XReify(#[source] InlineError<16>),
    /// The operands decoded to different dimensionalities.
    #[error("encountered while trying to compute distances")]
    UnequalLengths(#[source] UnequalLengths),
}
/// Object-safe fused computer: the query is already bound, only the data
/// operand is supplied per call.
pub trait DynQueryComputer: Send + Sync {
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError>;
    /// Layout of the query bound inside this computer.
    fn layout(&self) -> QueryLayout;
}
/// Type-erased, allocator-owned fused query computer.
pub struct QueryComputer<A = GlobalAllocator>
where
    A: AllocatorCore,
{
    inner: Poly<dyn DynQueryComputer, A>,
}

impl<A> QueryComputer<A>
where
    A: AllocatorCore,
{
    /// Boxes `inner` into `allocator` and erases its concrete type.
    fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
    where
        T: DynQueryComputer + 'static,
    {
        let inner = Poly::new(inner, allocator)?;
        Ok(Self {
            inner: poly!(DynQueryComputer, inner),
        })
    }
    /// Layout of the query bound inside this computer.
    pub fn layout(&self) -> QueryLayout {
        self.inner.layout()
    }
    /// Unwraps the type-erased inner computer.
    pub fn into_inner(self) -> Poly<dyn DynQueryComputer, A> {
        self.inner
    }
}
// The inner computer is type-erased, so `Debug` only reports the query layout.
impl<A> std::fmt::Debug for QueryComputer<A>
where
    A: AllocatorCore,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let layout = self.layout();
        write!(f, "dynamic fused query computer with layout \"{}\"", layout)
    }
}
// Adapts the fused computer to the generic preprocessed-distance interface.
impl<A> PreprocessedDistanceFunction<Opaque<'_>, Result<f32, QueryDistanceError>>
    for QueryComputer<A>
where
    A: AllocatorCore,
{
    fn evaluate_similarity(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
        self.inner.evaluate(x)
    }
}
/// A two-argument distance functor with its query operand pre-bound.
pub(super) struct Curried<D, Q> {
    inner: D,
    query: Q,
}

impl<D, Q> Curried<D, Q> {
    pub(super) fn new(inner: D, query: Q) -> Self {
        Self { inner, query }
    }
}

// One-argument evaluation: reborrow the stored query and forward both
// operands to the underlying two-argument functor.
impl<A, D, Q, T, R> Target1<A, R, T> for &Curried<D, Q>
where
    A: Architecture,
    Q: for<'a> Reborrow<'a>,
    D: for<'a> Target2<A, R, <Q as Reborrow<'a>>::Target, T> + Copy,
{
    fn run(self, arch: A, x: T) -> R {
        self.inner.run(arch, self.query.reborrow(), x)
    }
}
#[derive(Debug, Error)]
pub enum DistanceError {
#[error("trouble trying to reify the left-hand argument")]
QueryReify(InlineError<24>),
#[error("trouble trying to reify the right-hand argument")]
XReify(InlineError<16>),
#[error("encountered while trying to compute distances")]
UnequalLengths(UnequalLengths),
}
/// Object-safe two-sided distance computer over opaque byte views.
pub trait DynDistanceComputer: Send + Sync {
    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError>;
    /// Layout expected for the left-hand (`query`) operand.
    fn layout(&self) -> QueryLayout;
}
/// Type-erased, allocator-owned two-sided distance computer.
pub struct DistanceComputer<A = GlobalAllocator>
where
    A: AllocatorCore,
{
    inner: Poly<dyn DynDistanceComputer, A>,
}

impl<A> DistanceComputer<A>
where
    A: AllocatorCore,
{
    /// Boxes `inner` into `allocator` and erases its concrete type.
    pub(super) fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
    where
        T: DynDistanceComputer + 'static,
    {
        let inner = Poly::new(inner, allocator)?;
        Ok(Self {
            inner: poly!(DynDistanceComputer, inner),
        })
    }
    /// Layout expected for the left-hand (query) operand.
    pub fn layout(&self) -> QueryLayout {
        self.inner.layout()
    }
    /// Unwraps the type-erased inner computer.
    pub fn into_inner(self) -> Poly<dyn DynDistanceComputer, A> {
        self.inner
    }
}
// The inner computer is type-erased, so `Debug` only reports the query layout.
impl<A> std::fmt::Debug for DistanceComputer<A>
where
    A: AllocatorCore,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let layout = self.layout();
        write!(f, "dynamic distance computer with layout \"{}\"", layout)
    }
}
// Adapts the erased computer to the generic two-argument distance interface.
impl<A> DistanceFunction<Opaque<'_>, Opaque<'_>, Result<f32, DistanceError>> for DistanceComputer<A>
where
    A: AllocatorCore,
{
    fn evaluate_similarity(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
        self.inner.evaluate(query, x)
    }
}
// Initial capacity for the flatbuffer serialization buffer.
#[cfg(all(not(test), feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1024;
// In tests, start at 1 byte so the builder's reallocation path is exercised.
#[cfg(all(test, feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1;
/// Concrete `NBITS`-wide quantizer wrapper implementing the dynamic
/// [`Quantizer`] interface.
pub struct Impl<const NBITS: usize, A = GlobalAllocator>
where
    A: Allocator,
{
    quantizer: SphericalQuantizer<A>,
    // Cached data-vs-data computer, built once at construction.
    distance: Poly<dyn DynDistanceComputer, A>,
}
/// Construction hook: builds the cached data-vs-data distance computer for a
/// given quantizer. Implemented per bit width for [`Impl`].
pub trait Constructible<A = GlobalAllocator>
where
    A: Allocator,
{
    fn dispatch_distance(
        quantizer: &SphericalQuantizer<A>,
    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError>;
}
// Fix: the original declared `A: Allocator` both inline on the generic
// parameter and again in the `where` clause; the duplicate bound is removed.
impl<const NBITS: usize, A> Constructible<A> for Impl<NBITS, A>
where
    A: Allocator,
    AsData<NBITS>: FromOpaque,
    SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
{
    /// Runtime-dispatches to the best available architecture and builds the
    /// data-vs-data computer, unwrapping it to its type-erased inner handle.
    fn dispatch_distance(
        quantizer: &SphericalQuantizer<A>,
    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError> {
        diskann_wide::arch::dispatch2_no_features(
            ComputerDispatcher::<AsData<NBITS>, NBITS>::new(),
            quantizer,
            quantizer.allocator().clone(),
        )
        .map(|obj| obj.inner)
    }
}
// Cloning rebuilds the cached distance computer from the cloned quantizer.
impl<const NBITS: usize, A> TryClone for Impl<NBITS, A>
where
    A: Allocator,
    AsData<NBITS>: FromOpaque,
    SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
{
    fn try_clone(&self) -> Result<Self, AllocatorError> {
        Self::new(self.quantizer.try_clone()?)
    }
}
impl<const NBITS: usize, A: Allocator> Impl<NBITS, A> {
    /// Builds the wrapper and eagerly constructs the cached data-vs-data
    /// distance computer via runtime architecture dispatch.
    pub fn new(quantizer: SphericalQuantizer<A>) -> Result<Self, AllocatorError>
    where
        Self: Constructible<A>,
    {
        let distance = Self::dispatch_distance(&quantizer)?;
        Ok(Self {
            quantizer,
            distance,
        })
    }

    /// Borrows the wrapped quantizer.
    pub fn quantizer(&self) -> &SphericalQuantizer<A> {
        &self.quantizer
    }

    /// Query layouts supported by this bit width.
    pub fn supports(layout: QueryLayout) -> bool {
        // 1-bit data pairs with 4-bit transposed queries instead of dense
        // scalar-quantized ones; every other width is the reverse.
        if const { NBITS == 1 } {
            [
                QueryLayout::SameAsData,
                QueryLayout::FourBitTransposed,
                QueryLayout::FullPrecision,
            ]
            .contains(&layout)
        } else {
            [
                QueryLayout::SameAsData,
                QueryLayout::ScalarQuantized,
                QueryLayout::FullPrecision,
            ]
            .contains(&layout)
        }
    }

    /// Architecture-dispatches a two-sided computer whose query operand is
    /// reified as `Q`.
    fn query_computer<Q, B>(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError>
    where
        Q: FromOpaque,
        B: AllocatorCore,
        SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
    {
        diskann_wide::arch::dispatch2_no_features(
            ComputerDispatcher::<Q, NBITS>::new(),
            &self.quantizer,
            allocator,
        )
    }

    /// Compresses `query` into typed `storage`, mapping failures into the
    /// query-compression error domain.
    fn compress_query<'a, T>(
        &self,
        query: &'a [f32],
        storage: T,
        scratch: ScopedAllocator<'a>,
    ) -> Result<(), QueryCompressionError>
    where
        SphericalQuantizer<A>: CompressIntoWith<&'a [f32], T, ScopedAllocator<'a>, Error = quantizer::CompressionError>,
    {
        self.quantizer
            .compress_into_with(query, storage, scratch)
            .map_err(|err| CompressionError::CompressionError(err).into())
    }

    /// Compresses `query` into owned `storage`, then architecture-dispatches a
    /// fused computer with that storage bound into it.
    fn fused_query_computer<Q, T, B>(
        &self,
        query: &[f32],
        mut storage: T,
        allocator: B,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<B>, QueryComputerError>
    where
        Q: FromOpaque,
        T: for<'a> ReborrowMut<'a>
            + for<'a> Reborrow<'a, Target = Q::Target<'a>>
            + ReportQueryLayout
            + Send
            + Sync
            + 'static,
        B: AllocatorCore,
        SphericalQuantizer<A>: for<'a> CompressIntoWith<
            &'a [f32],
            <T as ReborrowMut<'a>>::Target,
            ScopedAllocator<'a>,
            Error = quantizer::CompressionError,
        >,
        SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
    {
        // Compress first so dispatch only happens with a valid query payload.
        if let Err(err) = self
            .quantizer
            .compress_into_with(query, storage.reborrow_mut(), scratch)
        {
            return Err(CompressionError::CompressionError(err).into());
        }
        diskann_wide::arch::dispatch3_no_features(
            ComputerDispatcher::<Q, NBITS>::new(),
            &self.quantizer,
            storage,
            allocator,
        )
        .map_err(|e| e.into())
    }

    /// Serializes the quantizer into a flatbuffer blob owned by `allocator`.
    #[cfg(feature = "flatbuffers")]
    fn serialize<B>(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError>
    where
        B: Allocator + std::panic::UnwindSafe,
        A: std::panic::RefUnwindSafe,
    {
        let mut buf = FlatBufferBuilder::new_in(Poly::broadcast(
            0u8,
            DEFAULT_SERIALIZED_BYTES,
            allocator.clone(),
        )?);
        let quantizer = &self.quantizer;
        // The flatbuffer builder reports allocation failure by panicking, so
        // packing runs under `catch_unwind` and allocator panics are converted
        // back into `AllocatorError`; any other panic is re-raised.
        let (root, mut buf) = match std::panic::catch_unwind(move || {
            let offset = quantizer.pack(&mut buf);
            let root = fb::spherical::Quantizer::create(
                &mut buf,
                &fb::spherical::QuantizerArgs {
                    quantizer: Some(offset),
                    nbits: NBITS as u32,
                },
            );
            (root, buf)
        }) {
            Ok(ret) => ret,
            Err(err) => match err.downcast_ref::<String>() {
                Some(msg) => {
                    // NOTE(review): matching on the panic message text is
                    // fragile — confirm the allocator's panic format.
                    if msg.contains("AllocatorError") {
                        return Err(AllocatorError);
                    } else {
                        std::panic::resume_unwind(err);
                    }
                }
                None => std::panic::resume_unwind(err),
            },
        };
        fb::spherical::finish_quantizer_buffer(&mut buf, root);
        Poly::from_iter(buf.finished_data().iter().copied(), allocator)
    }
}
/// Per-architecture computer factory, implemented by `dispatch_map!` for every
/// supported (architecture, query-reification, bit-width) combination.
trait BuildComputer<M, Q, const N: usize>
where
    M: Architecture,
    Q: FromOpaque,
{
    /// Builds a two-sided distance computer for architecture `arch`.
    fn build_computer<A>(
        &self,
        arch: M,
        allocator: A,
    ) -> Result<DistanceComputer<A>, AllocatorError>
    where
        A: AllocatorCore;
    /// Builds a fused computer with `query` pre-bound.
    fn build_fused_computer<R, A>(
        &self,
        arch: M,
        query: R,
        allocator: A,
    ) -> Result<QueryComputer<A>, AllocatorError>
    where
        R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
        A: AllocatorCore;
}
/// Returns its argument unchanged; the default architecture mapping used by
/// `dispatch_map!` when no downcast is requested.
fn identity<T>(x: T) -> T {
    x
}
/// Implements [`BuildComputer`] for `SphericalQuantizer` at bit width `$N`,
/// query reification `$Q`, and dispatch architecture `$arch`. The optional
/// `$op` maps the dispatched architecture onto the one the kernels actually
/// run on (e.g. downcasting V4 to V3); it defaults to `identity`.
macro_rules! dispatch_map {
    ($N:literal, $Q:ty, $arch:ty) => {
        dispatch_map!($N, $Q, $arch, identity);
    };
    ($N:literal, $Q:ty, $arch:ty, $op:ident) => {
        impl<A> BuildComputer<$arch, $Q, $N> for SphericalQuantizer<A>
        where
            A: Allocator,
        {
            fn build_computer<B>(
                &self,
                input_arch: $arch,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError>
            where
                B: AllocatorCore,
            {
                // Right-hand side is always data-layout at this bit width.
                type D = AsData<$N>;
                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                // Pick the compensated functor matching the trained metric.
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let reify = Reify::<CompensatedSquaredL2, _, $Q, D>::new(
                            self.as_functor(),
                            dim,
                            arch,
                        );
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::InnerProduct => {
                        let reify =
                            Reify::<CompensatedIP, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::Cosine => {
                        let reify =
                            Reify::<CompensatedCosine, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                }
            }
            fn build_fused_computer<R, B>(
                &self,
                input_arch: $arch,
                query: R,
                allocator: B,
            ) -> Result<QueryComputer<B>, AllocatorError>
            where
                R: ReportQueryLayout
                    + for<'a> Reborrow<'a, Target = <$Q as FromOpaque>::Target<'a>>
                    + Send
                    + Sync
                    + 'static,
                B: AllocatorCore,
            {
                type D = AsData<$N>;
                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                // `()` in the left slot of `Reify` marks the curried variant.
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let computer: CompensatedSquaredL2 = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::InnerProduct => {
                        let computer: CompensatedIP = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::Cosine => {
                        let computer: CompensatedCosine = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                }
            }
        }
    };
}
// Portable scalar kernels for every supported (bit width, query reification)
// combination. 1-bit data pairs with 4-bit transposed queries.
dispatch_map!(1, AsFull, Scalar);
dispatch_map!(2, AsFull, Scalar);
dispatch_map!(4, AsFull, Scalar);
dispatch_map!(8, AsFull, Scalar);
dispatch_map!(1, AsData<1>, Scalar);
dispatch_map!(2, AsData<2>, Scalar);
dispatch_map!(4, AsData<4>, Scalar);
dispatch_map!(8, AsData<8>, Scalar);
dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Scalar);
dispatch_map!(2, AsQuery<2>, Scalar);
dispatch_map!(4, AsQuery<4>, Scalar);
dispatch_map!(8, AsQuery<8>, Scalar);
cfg_if::cfg_if! {
    if #[cfg(target_arch = "x86_64")] {
        // Maps a detected V4 (AVX-512 class) machine onto the V3 kernels.
        fn downcast_to_v3(arch: V4) -> V3 {
            arch.into()
        }
        dispatch_map!(1, AsFull, V3);
        dispatch_map!(2, AsFull, V3);
        dispatch_map!(4, AsFull, V3);
        dispatch_map!(8, AsFull, V3);
        dispatch_map!(1, AsData<1>, V3);
        dispatch_map!(2, AsData<2>, V3);
        dispatch_map!(4, AsData<4>, V3);
        dispatch_map!(8, AsData<8>, V3);
        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V3);
        dispatch_map!(2, AsQuery<2>, V3);
        dispatch_map!(4, AsQuery<4>, V3);
        dispatch_map!(8, AsQuery<8>, V3);
        // On V4 machines most combinations fall back to V3 kernels through
        // `downcast_to_v3`; only the 2-bit entries keep native V4 dispatch.
        // NOTE(review): confirm the 2-bit exceptions are deliberate and not a
        // missed `downcast_to_v3` argument.
        dispatch_map!(1, AsFull, V4, downcast_to_v3);
        dispatch_map!(2, AsFull, V4, downcast_to_v3);
        dispatch_map!(4, AsFull, V4, downcast_to_v3);
        dispatch_map!(8, AsFull, V4, downcast_to_v3);
        dispatch_map!(1, AsData<1>, V4, downcast_to_v3);
        dispatch_map!(2, AsData<2>, V4);
        dispatch_map!(4, AsData<4>, V4, downcast_to_v3);
        dispatch_map!(8, AsData<8>, V4, downcast_to_v3);
        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3);
        dispatch_map!(2, AsQuery<2>, V4);
        dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3);
        dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3);
    } else if #[cfg(target_arch = "aarch64")] {
        // Neon currently reuses the portable scalar kernels.
        fn downcast(arch: Neon) -> Scalar {
            arch.retarget()
        }
        dispatch_map!(1, AsFull, Neon, downcast);
        dispatch_map!(2, AsFull, Neon, downcast);
        dispatch_map!(4, AsFull, Neon, downcast);
        dispatch_map!(8, AsFull, Neon, downcast);
        dispatch_map!(1, AsData<1>, Neon, downcast);
        dispatch_map!(2, AsData<2>, Neon, downcast);
        dispatch_map!(4, AsData<4>, Neon, downcast);
        dispatch_map!(8, AsData<8>, Neon, downcast);
        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Neon, downcast);
        dispatch_map!(2, AsQuery<2>, Neon, downcast);
        dispatch_map!(4, AsQuery<4>, Neon, downcast);
        dispatch_map!(8, AsQuery<8>, Neon, downcast);
    }
}
/// Zero-sized dispatcher handed to the arch-dispatch entry points; it carries
/// the query reification type `Q` and the bit width `N` at the type level.
#[derive(Debug, Clone, Copy)]
struct ComputerDispatcher<Q, const N: usize> {
    _query_type: std::marker::PhantomData<Q>,
}

impl<Q, const N: usize> ComputerDispatcher<Q, N> {
    /// Creates the stateless dispatcher.
    fn new() -> Self {
        ComputerDispatcher {
            _query_type: std::marker::PhantomData,
        }
    }
}
// Dispatch target: builds a plain two-sided computer on architecture `M`.
impl<M, const N: usize, A, B, Q>
    diskann_wide::arch::Target2<
        M,
        Result<DistanceComputer<B>, AllocatorError>,
        &SphericalQuantizer<A>,
        B,
    > for ComputerDispatcher<Q, N>
where
    M: Architecture,
    A: Allocator,
    B: AllocatorCore,
    Q: FromOpaque,
    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
{
    fn run(
        self,
        arch: M,
        quantizer: &SphericalQuantizer<A>,
        allocator: B,
    ) -> Result<DistanceComputer<B>, AllocatorError> {
        quantizer.build_computer(arch, allocator)
    }
}

// Dispatch target: builds a fused computer with `query` pre-bound.
impl<M, const N: usize, A, R, B, Q>
    diskann_wide::arch::Target3<
        M,
        Result<QueryComputer<B>, AllocatorError>,
        &SphericalQuantizer<A>,
        R,
        B,
    > for ComputerDispatcher<Q, N>
where
    M: Architecture,
    A: Allocator,
    B: AllocatorCore,
    Q: FromOpaque,
    R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
{
    fn run(
        self,
        arch: M,
        quantizer: &SphericalQuantizer<A>,
        query: R,
        allocator: B,
    ) -> Result<QueryComputer<B>, AllocatorError> {
        quantizer.build_fused_computer(arch, query, allocator)
    }
}
// `Dispatchable` bundles the `BuildComputer` impls required for every
// architecture the current target can dispatch to; blanket-implemented per
// target family.
#[cfg(target_arch = "x86_64")]
trait Dispatchable<Q, const N: usize>:
    BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>
where
    Q: FromOpaque,
{
}
#[cfg(target_arch = "x86_64")]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>,
{
}
#[cfg(target_arch = "aarch64")]
trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N> + BuildComputer<Neon, Q, N>
where
    Q: FromOpaque,
{
}
#[cfg(target_arch = "aarch64")]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N> + BuildComputer<Neon, Q, N>,
{
}
// Fallback targets only need the portable scalar kernels.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N>
where
    Q: FromOpaque,
{
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N>,
{
}
// Hand-written 1-bit implementation: unlike the widths covered by `plan!`,
// 1-bit data rejects `ScalarQuantized` queries and accepts 4-bit
// bit-transposed ones instead.
impl<A, B> Quantizer<B> for Impl<1, A>
where
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
    B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    fn nbits(&self) -> usize {
        1
    }
    fn dim(&self) -> usize {
        self.quantizer.output_dim()
    }
    fn full_dim(&self) -> usize {
        self.quantizer.input_dim()
    }
    fn bytes(&self) -> usize {
        DataRef::<1>::canonical_bytes(self.quantizer.output_dim())
    }
    fn distance_computer(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError> {
        self.query_computer::<AsData<1>, _>(allocator)
    }
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
        &*self.distance
    }
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: B,
    ) -> Result<DistanceComputer<B>, DistanceComputerError> {
        match layout {
            QueryLayout::SameAsData => Ok(self.query_computer::<AsData<1>, _>(allocator)?),
            QueryLayout::FourBitTransposed => {
                Ok(self.query_computer::<AsQuery<4, bits::BitTranspose>, _>(allocator)?)
            }
            // Dense scalar-quantized queries are not defined for 1-bit data.
            QueryLayout::ScalarQuantized => {
                Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
            }
            QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
        }
    }
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
        let dim = <Self as Quantizer<B>>::dim(self);
        match layout {
            QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
                DataRef::<1>::canonical_bytes(dim),
                PowerOfTwo::alignment_of::<u8>(),
            )),
            QueryLayout::FourBitTransposed => Ok(QueryBufferDescription::new(
                QueryRef::<4, bits::BitTranspose>::canonical_bytes(dim),
                PowerOfTwo::alignment_of::<u8>(),
            )),
            QueryLayout::ScalarQuantized => {
                Err(UnsupportedQueryLayout::new(layout, "1-bit compression"))
            }
            QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
                FullQueryRef::canonical_bytes(dim),
                FullQueryRef::canonical_align(),
            )),
        }
    }
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        mut buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        // Reinterpret the opaque buffer for the requested layout and forward
        // to the typed inherent `compress_query`.
        let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
            match layout {
                QueryLayout::SameAsData => self.compress_query(
                    v,
                    DataMut::<1>::from_canonical_back_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
                QueryLayout::FourBitTransposed => self.compress_query(
                    v,
                    QueryMut::<4, bits::BitTranspose>::from_canonical_back_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
                QueryLayout::ScalarQuantized => {
                    Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
                }
                QueryLayout::FullPrecision => self.compress_query(
                    v,
                    FullQueryMut::from_canonical_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
            }
        };
        // Inner-product queries may be rescaled on an owned copy first.
        if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
            let mut copy = x.to_owned();
            self.quantizer.rescale(&mut copy);
            finish(&copy)
        } else {
            finish(x)
        }
    }
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: B,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<B>, QueryComputerError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        // Allocate owned storage for the chosen layout and forward to the
        // typed inherent `fused_query_computer`.
        let finish = |v: &[f32], allocator: B| -> Result<QueryComputer<B>, QueryComputerError> {
            match layout {
                QueryLayout::SameAsData => self.fused_query_computer::<AsData<1>, Data<1, _>, _>(
                    v,
                    Data::new_in(dim, allocator.clone())?,
                    allocator,
                    scratch,
                ),
                QueryLayout::FourBitTransposed => self
                    .fused_query_computer::<AsQuery<4, bits::BitTranspose>, Query<4, bits::BitTranspose, _>, _>(
                        v,
                        Query::new_in(dim, allocator.clone())?,
                        allocator,
                        scratch,
                    ),
                QueryLayout::ScalarQuantized => {
                    Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
                }
                QueryLayout::FullPrecision => self.fused_query_computer::<AsFull, FullQuery<_>, _>(
                    v,
                    FullQuery::empty(dim, allocator.clone())?,
                    allocator,
                    scratch,
                ),
            }
        };
        if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
            let mut copy = x.to_owned();
            self.quantizer.rescale(&mut copy);
            finish(&copy, allocator)
        } else {
            finish(x, allocator)
        }
    }
    fn is_supported(&self, layout: QueryLayout) -> bool {
        Self::supports(layout)
    }
    fn compress(
        &self,
        x: &[f32],
        mut into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        let into = DataMut::<1>::from_canonical_back_mut(into.inspect(), dim)
            .map_err(CompressionError::not_canonical)?;
        self.quantizer
            .compress_into_with(x, into, scratch)
            .map_err(CompressionError::CompressionError)
    }
    fn metric(&self) -> SupportedMetric {
        self.quantizer.metric()
    }
    fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
        let clone = (*self).try_clone()?;
        poly!({ Quantizer<B> }, clone, allocator)
    }
    #[cfg(feature = "flatbuffers")]
    fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
        // Disambiguates to the inherent serialize on `Impl`.
        Impl::<1, A>::serialize(self, allocator)
    }
}
/// Generates the object-safe [`Quantizer`] implementation for each multi-bit
/// compression width. The 1-bit implementation is written out separately;
/// every width expanded here rejects `QueryLayout::FourBitTransposed`, which
/// the tests confirm is only available for 1-bit data.
///
/// Fix: the rescale branches of `compress_query` and `fused_query_computer`
/// previously read `finish(©)` — mojibake of `finish(&copy)` (an HTML-escaped
/// `&copy` rendered as the copyright sign) — which is not valid Rust.
/// Restored to `finish(&copy)`.
macro_rules! plan {
    ($N:literal) => {
        impl<A, B> Quantizer<B> for Impl<$N, A>
        where
            A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
            B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
        {
            fn nbits(&self) -> usize {
                $N
            }
            /// Output (compressed) dimensionality.
            fn dim(&self) -> usize {
                self.quantizer.output_dim()
            }
            /// Input (full-precision) dimensionality.
            fn full_dim(&self) -> usize {
                self.quantizer.input_dim()
            }
            /// Canonical size in bytes of one compressed vector.
            fn bytes(&self) -> usize {
                DataRef::<$N>::canonical_bytes(<Self as Quantizer<B>>::dim(self))
            }
            fn distance_computer(
                &self,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError> {
                self.query_computer::<AsData<$N>, _>(allocator)
            }
            fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
                &*self.distance
            }
            fn query_computer(
                &self,
                layout: QueryLayout,
                allocator: B,
            ) -> Result<DistanceComputer<B>, DistanceComputerError> {
                match layout {
                    QueryLayout::SameAsData => {
                        Ok(self.query_computer::<AsData<$N>, _>(allocator)?)
                    }
                    // Transposed queries are only implemented for 1-bit data.
                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                        layout,
                        concat!($N, "-bit compression"),
                    )
                    .into()),
                    QueryLayout::ScalarQuantized => {
                        Ok(self.query_computer::<AsQuery<$N, bits::Dense>, _>(allocator)?)
                    }
                    QueryLayout::FullPrecision => {
                        Ok(self.query_computer::<AsFull, _>(allocator)?)
                    }
                }
            }
            fn query_buffer_description(
                &self,
                layout: QueryLayout,
            ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
                let dim = <Self as Quantizer<B>>::dim(self);
                match layout {
                    QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
                        DataRef::<$N>::canonical_bytes(dim),
                        PowerOfTwo::alignment_of::<u8>(),
                    )),
                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout {
                        layout,
                        desc: concat!($N, "-bit compression"),
                    }),
                    QueryLayout::ScalarQuantized => Ok(QueryBufferDescription::new(
                        QueryRef::<$N, bits::Dense>::canonical_bytes(dim),
                        PowerOfTwo::alignment_of::<u8>(),
                    )),
                    QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
                        FullQueryRef::canonical_bytes(dim),
                        FullQueryRef::canonical_align(),
                    )),
                }
            }
            /// Compresses `x` into `buffer` using the requested query layout.
            /// For inner-product metrics with `allow_rescale`, a rescaled copy
            /// of `x` is compressed instead of `x` itself.
            fn compress_query(
                &self,
                x: &[f32],
                layout: QueryLayout,
                allow_rescale: bool,
                mut buffer: OpaqueMut<'_>,
                scratch: ScopedAllocator<'_>,
            ) -> Result<(), QueryCompressionError> {
                let dim = <Self as Quantizer<B>>::dim(self);
                let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
                    match layout {
                        QueryLayout::SameAsData => self.compress_query(
                            v,
                            DataMut::<$N>::from_canonical_back_mut(&mut buffer, dim)
                                .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                        QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                            layout,
                            concat!($N, "-bit compression"),
                        )
                        .into()),
                        QueryLayout::ScalarQuantized => self.compress_query(
                            v,
                            QueryMut::<$N, bits::Dense>::from_canonical_back_mut(
                                &mut buffer,
                                dim,
                            )
                            .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                        QueryLayout::FullPrecision => self.compress_query(
                            v,
                            FullQueryMut::from_canonical_mut(&mut buffer, dim)
                                .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                    }
                };
                if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
                    let mut copy = x.to_owned();
                    self.quantizer.rescale(&mut copy);
                    finish(&copy)
                } else {
                    finish(x)
                }
            }
            fn fused_query_computer(
                &self,
                x: &[f32],
                layout: QueryLayout,
                allow_rescale: bool,
                allocator: B,
                scratch: ScopedAllocator<'_>,
            ) -> Result<QueryComputer<B>, QueryComputerError> {
                let dim = <Self as Quantizer<B>>::dim(self);
                let finish = |v: &[f32]| -> Result<QueryComputer<B>, QueryComputerError> {
                    match layout {
                        QueryLayout::SameAsData => {
                            self.fused_query_computer::<AsData<$N>, Data<$N, _>, B>(
                                v,
                                Data::new_in(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                        QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                            layout,
                            concat!($N, "-bit compression"),
                        )
                        .into()),
                        QueryLayout::ScalarQuantized => {
                            self.fused_query_computer::<AsQuery<$N, bits::Dense>, Query<$N, bits::Dense, _>, B>(
                                v,
                                Query::new_in(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                        QueryLayout::FullPrecision => {
                            self.fused_query_computer::<AsFull, FullQuery<_>, B>(
                                v,
                                FullQuery::empty(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                    }
                };
                let metric = <Self as Quantizer<B>>::metric(self);
                if allow_rescale && metric == SupportedMetric::InnerProduct {
                    let mut copy = x.to_owned();
                    self.quantizer.rescale(&mut copy);
                    finish(&copy)
                } else {
                    finish(x)
                }
            }
            fn is_supported(&self, layout: QueryLayout) -> bool {
                Self::supports(layout)
            }
            fn compress(
                &self,
                x: &[f32],
                mut into: OpaqueMut<'_>,
                scratch: ScopedAllocator<'_>,
            ) -> Result<(), CompressionError> {
                let dim = <Self as Quantizer<B>>::dim(self);
                let into = DataMut::<$N>::from_canonical_back_mut(into.inspect(), dim)
                    .map_err(CompressionError::not_canonical)?;
                self.quantizer
                    .compress_into_with(x, into, scratch)
                    .map_err(CompressionError::CompressionError)
            }
            fn metric(&self) -> SupportedMetric {
                self.quantizer.metric()
            }
            fn try_clone_into(
                &self,
                allocator: B,
            ) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
                let clone = (&*self).try_clone()?;
                poly!({ Quantizer<B> }, clone, allocator)
            }
            #[cfg(feature = "flatbuffers")]
            fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
                Impl::<$N, A>::serialize(self, allocator)
            }
        }
    };
    // Variadic form: expand the single-width arm once per listed width.
    ($N:literal, $($Ns:literal),*) => {
        plan!($N);
        $(plan!($Ns);)*
    }
}
plan!(2, 4, 8);
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
/// Errors that can occur while reconstructing a quantizer from a serialized
/// flatbuffer (see [`try_deserialize`]).
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The buffer does not carry the expected flatbuffer file identifier.
    #[error("unhandled file identifier in flatbuffer")]
    InvalidIdentifier,
    /// The serialized bit-width is not one of the handled values
    /// (1, 2, 4, or 8 — see the `match` in `try_deserialize`).
    #[error("unsupported number of bits ({0})")]
    UnsupportedBitWidth(u32),
    /// The inner quantizer payload failed to unpack.
    #[error(transparent)]
    InvalidQuantizer(#[from] super::quantizer::DeserializationError),
    /// The buffer failed flatbuffer verification.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),
    /// An allocation failed while rebuilding the quantizer.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
/// Deserializes a type-erased [`Quantizer`] from the flatbuffer bytes in
/// `data`, allocating the result with `alloc`.
///
/// The bit-width stored in the buffer selects which `Impl<NBITS, A>` is
/// constructed; widths other than 1, 2, 4, or 8 yield
/// [`DeserializationError::UnsupportedBitWidth`].
pub fn try_deserialize<O, A>(
    data: &[u8],
    alloc: A,
) -> Result<Poly<dyn Quantizer<O>, A>, DeserializationError>
where
    O: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
{
    // Monomorphized helper: unpack the payload for one concrete bit-width.
    fn unpack_bits<'a, const NBITS: usize, O, A>(
        proto: fb::spherical::SphericalQuantizer<'_>,
        alloc: A,
    ) -> Result<Poly<dyn Quantizer<O> + 'a, A>, DeserializationError>
    where
        O: Allocator + Send + Sync + std::panic::UnwindSafe + 'static,
        A: Allocator + Send + Sync + 'a,
        Impl<NBITS, A>: Quantizer<O> + Constructible<A>,
    {
        // `Poly::new_with` allocates the returned box before the constructor
        // runs; the allocation-order test below ("expected the returned box
        // to be allocated first") relies on this ordering.
        let imp = match Poly::new_with(
            // `inline(never)` keeps this large constructor out of the
            // per-NBITS monomorphized fast path.
            #[inline(never)]
            |alloc| -> Result<_, super::quantizer::DeserializationError> {
                let quantizer = SphericalQuantizer::try_unpack(alloc, proto)?;
                Ok(Impl::new(quantizer)?)
            },
            alloc,
        ) {
            Ok(imp) => imp,
            Err(CompoundError::Allocator(err)) => {
                return Err(err.into());
            }
            Err(CompoundError::Constructor(err)) => {
                return Err(err.into());
            }
        };
        Ok(poly!({ Quantizer<O> }, imp))
    }
    if !fb::spherical::quantizer_buffer_has_identifier(data) {
        return Err(DeserializationError::InvalidIdentifier);
    }
    let root = fb::spherical::root_as_quantizer(data)?;
    let nbits = root.nbits();
    let proto = root.quantizer();
    match nbits {
        1 => unpack_bits::<1, _, _>(proto, alloc),
        2 => unpack_bits::<2, _, _>(proto, alloc),
        4 => unpack_bits::<4, _, _>(proto, alloc),
        8 => unpack_bits::<8, _, _>(proto, alloc),
        n => Err(DeserializationError::UnsupportedBitWidth(n)),
    }
}
#[cfg(test)]
mod tests {
use diskann_utils::views::{Matrix, MatrixView};
use rand::{SeedableRng, rngs::StdRng};
use super::*;
use crate::{
algorithms::{TransformKind, transforms::TargetDim},
alloc::{AlignedAllocator, GlobalAllocator, Poly},
num::PowerOfTwo,
spherical::PreScale,
};
/// Checks the query-layout support matrix of a 1-bit plan: every layout
/// except `ScalarQuantized` should be reported as supported.
fn test_plan_1_bit(plan: &dyn Quantizer) {
    assert_eq!(
        plan.nbits(),
        1,
        "this test only applies to 1-bit quantization"
    );
    for layout in QueryLayout::all() {
        if matches!(layout, QueryLayout::ScalarQuantized) {
            assert!(
                !plan.is_supported(layout),
                "expected {} to not be supported",
                layout
            );
        } else {
            assert!(
                plan.is_supported(layout),
                "expected {} to be supported",
                layout
            );
        }
    }
}
/// Checks the query-layout support matrix of a multi-bit plan: every layout
/// except `FourBitTransposed` should be reported as supported.
fn test_plan_n_bit(plan: &dyn Quantizer, nbits: usize) {
    assert_ne!(nbits, 1, "there is another test for 1-bit quantizers");
    assert_eq!(
        plan.nbits(),
        nbits,
        // Fixed copy-paste: the old message claimed this test was for
        // "1-bit quantization" even though the test explicitly excludes
        // the 1-bit case above.
        "this test only applies to {}-bit quantization",
        nbits
    );
    for layout in QueryLayout::all() {
        match layout {
            QueryLayout::SameAsData
            | QueryLayout::ScalarQuantized
            | QueryLayout::FullPrecision => assert!(
                plan.is_supported(layout),
                "expected {} to be supported",
                layout
            ),
            QueryLayout::FourBitTransposed => assert!(
                !plan.is_supported(layout),
                "expected {} to not be supported",
                layout
            ),
        }
    }
}
#[inline(never)]
/// End-to-end battery for one plan: layout support, compression, distance
/// computers (owned and borrowed), per-layout query computers, and the error
/// paths for undersized/oversized buffers and inputs.
fn test_plan(plan: &dyn Quantizer, nbits: usize, dataset: MatrixView<f32>) {
    // Layout-support expectations differ between 1-bit and multi-bit plans.
    if nbits == 1 {
        test_plan_1_bit(plan);
    } else {
        test_plan_n_bit(plan, nbits);
    }
    assert_eq!(plan.full_dim(), dataset.ncols());
    let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
    // Two compressed codes (rows 0 and 1) reused throughout the test.
    let mut a = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
    let mut b = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
    let scoped_global = ScopedAllocator::global();
    plan.compress(dataset.row(0), OpaqueMut::new(&mut a), scoped_global)
        .unwrap();
    plan.compress(dataset.row(1), OpaqueMut::new(&mut b), scoped_global)
        .unwrap();
    let f = plan.distance_computer(GlobalAllocator).unwrap();
    let _: f32 = f
        .evaluate_similarity(Opaque::new(&a), Opaque::new(&b))
        .unwrap();
    // Wrong-sized operands must fail reification on the corresponding side.
    let test_errors = |f: &dyn DynDistanceComputer| {
        let err = f
            .evaluate(Opaque::new(&a[..a.len() - 1]), Opaque::new(&b))
            .unwrap_err();
        assert!(matches!(err, DistanceError::QueryReify(_)));
        let err = f
            .evaluate(Opaque::new(&vec![0u8; a.len() + 1]), Opaque::new(&b))
            .unwrap_err();
        assert!(matches!(err, DistanceError::QueryReify(_)));
        let err = f
            .evaluate(Opaque::new(&a), Opaque::new(&b[..b.len() - 1]))
            .unwrap_err();
        assert!(matches!(err, DistanceError::XReify(_)));
        let err = f
            .evaluate(Opaque::new(&a), Opaque::new(&vec![0u8; b.len() + 1]))
            .unwrap_err();
        assert!(matches!(err, DistanceError::XReify(_)));
    };
    test_errors(&*f.inner);
    // Same checks through the borrowed (reference) computer.
    let f = plan.distance_computer_ref();
    let _: f32 = f.evaluate(Opaque::new(&a), Opaque::new(&b)).unwrap();
    test_errors(f);
    for layout in QueryLayout::all() {
        if !plan.is_supported(layout) {
            // Unsupported layouts: every entry point must fail with a
            // message naming both the layout and the bit-width.
            let check_message = |msg: &str| {
                assert!(
                    msg.contains(&(layout.to_string())),
                    "error message ({}) should contain the layout \"{}\"",
                    msg,
                    layout
                );
                assert!(
                    msg.contains(&format!("{}", nbits)),
                    "error message ({}) should contain the number of bits \"{}\"",
                    msg,
                    nbits
                );
            };
            {
                let err = plan
                    .fused_query_computer(
                        dataset.row(1),
                        layout,
                        false,
                        GlobalAllocator,
                        scoped_global,
                    )
                    .unwrap_err();
                let msg = err.to_string();
                check_message(&msg);
            }
            {
                let err = plan.query_buffer_description(layout).unwrap_err();
                let msg = err.to_string();
                check_message(&msg);
            }
            {
                let buffer = &mut [];
                let err = plan
                    .compress_query(
                        dataset.row(1),
                        layout,
                        true,
                        OpaqueMut::new(buffer),
                        scoped_global,
                    )
                    .unwrap_err();
                let msg = err.to_string();
                check_message(&msg);
            }
            {
                let err = plan.query_computer(layout, GlobalAllocator).unwrap_err();
                let msg = err.to_string();
                check_message(&msg);
            }
            continue;
        }
        // Supported layouts: fused and standalone computers must agree.
        let g = plan
            .fused_query_computer(
                dataset.row(1),
                layout,
                false,
                GlobalAllocator,
                scoped_global,
            )
            .unwrap();
        assert_eq!(
            g.layout(),
            layout,
            "the query computer should faithfully preserve the requested layout"
        );
        let direct: f32 = g.evaluate_similarity(Opaque(&a)).unwrap();
        {
            let err = g
                .evaluate_similarity(Opaque::new(&a[..a.len() - 1]))
                .unwrap_err();
            assert!(matches!(err, QueryDistanceError::XReify(_)));
            let err = g
                .evaluate_similarity(Opaque::new(&vec![0u8; a.len() + 1]))
                .unwrap_err();
            assert!(matches!(err, QueryDistanceError::XReify(_)));
        }
        // Compress the query into a buffer sized/aligned per the plan's
        // advertised description, then evaluate through the standalone path.
        let sizes = plan.query_buffer_description(layout).unwrap();
        let mut buf =
            Poly::broadcast(0u8, sizes.bytes(), AlignedAllocator::new(sizes.align())).unwrap();
        plan.compress_query(
            dataset.row(1),
            layout,
            false,
            OpaqueMut::new(&mut buf),
            scoped_global,
        )
        .unwrap();
        let standalone = plan.query_computer(layout, GlobalAllocator).unwrap();
        assert_eq!(
            standalone.layout(),
            layout,
            "the standalone computer did not preserve the requested layout",
        );
        let indirect: f32 = standalone
            .evaluate_similarity(Opaque(&buf), Opaque(&a))
            .unwrap();
        assert_eq!(
            direct, indirect,
            "the two different query computation APIs did not return the same result"
        );
        // A query with too few dimensions must be rejected.
        let too_small = &dataset.row(0)[..dataset.ncols() - 1];
        assert!(
            plan.fused_query_computer(too_small, layout, false, GlobalAllocator, scoped_global)
                .is_err()
        );
    }
    {
        // Compression must reject wrong-sized output buffers and inputs.
        let mut too_small = vec![u8::default(); plan.bytes() - 1];
        assert!(
            plan.compress(dataset.row(0), OpaqueMut(&mut too_small), scoped_global)
                .is_err()
        );
        let mut too_big = vec![u8::default(); plan.bytes() + 1];
        assert!(
            plan.compress(dataset.row(0), OpaqueMut(&mut too_big), scoped_global)
                .is_err()
        );
        let mut just_right = vec![u8::default(); plan.bytes()];
        assert!(
            plan.compress(
                &dataset.row(0)[..dataset.ncols() - 1],
                OpaqueMut(&mut just_right),
                scoped_global
            )
            .is_err()
        );
    }
}
/// Trains a fresh `Impl<NBITS>` over the shared test dataset with a fixed
/// RNG seed (deterministic across runs) and returns it with the dataset.
fn make_impl<const NBITS: usize>(metric: SupportedMetric) -> (Impl<NBITS>, Matrix<f32>)
where
    Impl<NBITS>: Constructible,
{
    let dataset = test_dataset();
    let mut generator = StdRng::seed_from_u64(0x7d535118722ff197);
    let trained = SphericalQuantizer::train(
        dataset.as_view(),
        TransformKind::PaddingHadamard {
            target_dim: TargetDim::Natural,
        },
        metric,
        PreScale::None,
        &mut generator,
        GlobalAllocator,
    )
    .unwrap();
    let plan = Impl::<NBITS>::new(trained).unwrap();
    (plan, dataset)
}
/// Expands to one `#[test]` per (bit-width, metric) pair; each test trains
/// a plan and runs the shared `test_plan` battery against it.
macro_rules! plan_layout_test {
    ($name:ident, $nbits:literal, $metric:ident) => {
        #[test]
        fn $name() {
            let (plan, data) = make_impl::<$nbits>(SupportedMetric::$metric);
            test_plan(&plan, $nbits, data.as_view());
        }
    };
}
plan_layout_test!(test_plan_1bit_l2, 1, SquaredL2);
plan_layout_test!(test_plan_1bit_ip, 1, InnerProduct);
plan_layout_test!(test_plan_1bit_cosine, 1, Cosine);
plan_layout_test!(test_plan_2bit_l2, 2, SquaredL2);
plan_layout_test!(test_plan_2bit_ip, 2, InnerProduct);
plan_layout_test!(test_plan_2bit_cosine, 2, Cosine);
plan_layout_test!(test_plan_4bit_l2, 4, SquaredL2);
plan_layout_test!(test_plan_4bit_ip, 4, InnerProduct);
plan_layout_test!(test_plan_4bit_cosine, 4, Cosine);
plan_layout_test!(test_plan_8bit_l2, 8, SquaredL2);
plan_layout_test!(test_plan_8bit_ip, 8, InnerProduct);
plan_layout_test!(test_plan_8bit_cosine, 8, Cosine);
/// Returns the fixed 16x8 matrix used by every test in this module.
/// The values are arbitrary but frozen: the compatibility baselines are
/// generated from exactly this data, so do not edit them.
fn test_dataset() -> Matrix<f32> {
    // One source line per matrix row (8 columns each).
    let data = vec![
        0.28657, -0.0318168, 0.0666847, 0.0329265, -0.00829283, 0.168735, -0.000846311, -0.360779,
        -0.0968938, 0.161921, -0.0979579, 0.102228, -0.259928, -0.139634, 0.165384, -0.293443,
        0.130205, 0.265737, 0.401816, -0.407552, 0.13012, -0.0475244, 0.511723, -0.4372,
        -0.0979126, 0.135861, -0.0154144, -0.14047, -0.0250029, -0.190279, 0.407283, -0.389184,
        -0.264153, 0.0696822, -0.145585, 0.370284, 0.186825, -0.140736, 0.274703, -0.334563,
        0.247613, 0.513165, -0.0845867, 0.0532264, -0.00480601, -0.122408, 0.47227, -0.268301,
        0.103198, 0.30756, -0.316293, -0.0686877, -0.330729, -0.461997, 0.550857, -0.240851,
        0.128258, 0.786291, -0.0268103, 0.111763, -0.308962, -0.17407, 0.437154, -0.159879,
        0.00374063, 0.490301, 0.0327826, -0.0340962, -0.118605, 0.163879, 0.2737, -0.299942,
        -0.284077, 0.249377, -0.0307734, -0.0661631, 0.233854, 0.427987, 0.614132, -0.288649,
        -0.109492, 0.203939, -0.73956, -0.130748, 0.22072, 0.0647836, 0.328726, -0.374602,
        -0.223114, 0.0243489, 0.109195, -0.416914, 0.0201052, -0.0190542, 0.947078, -0.333229,
        -0.165869, -0.00296729, -0.414378, 0.231321, 0.205365, 0.161761, 0.148608, -0.395063,
        -0.0498255, 0.193279, -0.110946, -0.181174, -0.274578, -0.227511, 0.190208, -0.256174,
        -0.188106, -0.0292958, 0.0930939, 0.0558456, 0.257437, 0.685481, 0.307922, -0.320006,
        0.250035, 0.275942, -0.0856306, -0.352027, -0.103509, -0.00890859, 0.276121, -0.324718,
    ];
    Matrix::try_from(data.into(), 16, 8).unwrap()
}
#[cfg(feature = "flatbuffers")]
mod serialization {
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use super::*;
use crate::alloc::{BumpAllocator, GlobalAllocator};
#[inline(never)]
/// Round-trips `quantizer` through serialize/deserialize and verifies that
/// the reconstruction agrees on metadata, compression output, and every
/// distance-computation path.
fn test_plan_serialization(
    quantizer: &dyn Quantizer,
    nbits: usize,
    dataset: MatrixView<f32>,
) {
    assert_eq!(quantizer.full_dim(), dataset.ncols());
    let scoped_global = ScopedAllocator::global();
    let serialized = quantizer.serialize(GlobalAllocator).unwrap();
    let deserialized =
        try_deserialize::<GlobalAllocator, _>(&serialized, GlobalAllocator).unwrap();
    // Metadata must survive the round trip.
    assert_eq!(deserialized.nbits(), nbits);
    assert_eq!(deserialized.bytes(), quantizer.bytes());
    assert_eq!(deserialized.dim(), quantizer.dim());
    assert_eq!(deserialized.full_dim(), quantizer.full_dim());
    assert_eq!(deserialized.metric(), quantizer.metric());
    for layout in QueryLayout::all() {
        assert_eq!(
            deserialized.is_supported(layout),
            quantizer.is_supported(layout)
        );
    }
    let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
    {
        // Both quantizers must produce bit-identical codes for every row.
        let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        for row in dataset.row_iter() {
            quantizer
                .compress(row, OpaqueMut::new(&mut a), scoped_global)
                .unwrap();
            deserialized
                .compress(row, OpaqueMut::new(&mut b), scoped_global)
                .unwrap();
            assert_eq!(a, b);
        }
    }
    {
        // All four distance computers (owned/borrowed x original/deserialized)
        // must agree on every pair of codes.
        let mut a0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        let mut a1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        let mut b0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        let mut b1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        let q_computer = quantizer.distance_computer(GlobalAllocator).unwrap();
        let q_computer_ref = quantizer.distance_computer_ref();
        let d_computer = deserialized.distance_computer(GlobalAllocator).unwrap();
        let d_computer_ref = deserialized.distance_computer_ref();
        for r0 in dataset.row_iter() {
            quantizer
                .compress(r0, OpaqueMut::new(&mut a0), scoped_global)
                .unwrap();
            deserialized
                .compress(r0, OpaqueMut::new(&mut b0), scoped_global)
                .unwrap();
            for r1 in dataset.row_iter() {
                quantizer
                    .compress(r1, OpaqueMut::new(&mut a1), scoped_global)
                    .unwrap();
                deserialized
                    .compress(r1, OpaqueMut::new(&mut b1), scoped_global)
                    .unwrap();
                // NOTE(review): only the original quantizer's codes (a0/a1)
                // are evaluated below; b0/b1 are compressed but never used in
                // this section — confirm whether the deserialized codes were
                // meant to be evaluated here (they are compared byte-for-byte
                // in the section above).
                let a0 = Opaque::new(&a0);
                let a1 = Opaque::new(&a1);
                let q_computer_dist = q_computer.evaluate_similarity(a0, a1).unwrap();
                let d_computer_dist = d_computer.evaluate_similarity(a0, a1).unwrap();
                assert_eq!(q_computer_dist, d_computer_dist);
                let q_computer_ref_dist = q_computer_ref.evaluate(a0, a1).unwrap();
                assert_eq!(q_computer_dist, q_computer_ref_dist);
                let d_computer_ref_dist = d_computer_ref.evaluate(a0, a1).unwrap();
                assert_eq!(d_computer_dist, d_computer_ref_dist);
            }
        }
    }
    {
        // Fused query computers from both quantizers must agree for every
        // supported layout, query row, and data row.
        let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        for layout in QueryLayout::all() {
            if !quantizer.is_supported(layout) {
                continue;
            }
            for r in dataset.row_iter() {
                let q_computer = quantizer
                    .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
                    .unwrap();
                let d_computer = deserialized
                    .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
                    .unwrap();
                for u in dataset.row_iter() {
                    quantizer
                        .compress(u, OpaqueMut::new(&mut a), scoped_global)
                        .unwrap();
                    deserialized
                        .compress(u, OpaqueMut::new(&mut b), scoped_global)
                        .unwrap();
                    assert_eq!(
                        q_computer.evaluate_similarity(Opaque::new(&a)).unwrap(),
                        d_computer.evaluate_similarity(Opaque::new(&b)).unwrap(),
                    );
                }
            }
        }
    }
}
/// An allocator that succeeds exactly once and fails every subsequent
/// allocation — used to exercise error handling at allocation boundaries.
#[derive(Debug, Clone)]
struct FlakyAllocator {
    // Shared flag; flipped to `true` by the first (successful) allocation.
    have_allocated: Arc<AtomicBool>,
}
impl FlakyAllocator {
    fn new(have_allocated: Arc<AtomicBool>) -> Self {
        Self { have_allocated }
    }
}
unsafe impl AllocatorCore for FlakyAllocator {
    fn allocate(
        &self,
        layout: std::alloc::Layout,
    ) -> Result<std::ptr::NonNull<[u8]>, AllocatorError> {
        // `swap` returns the previous value: the first call observes `false`
        // and delegates to the global allocator; every later call fails.
        if self.have_allocated.swap(true, Ordering::Relaxed) {
            Err(AllocatorError)
        } else {
            GlobalAllocator.allocate(layout)
        }
    }
    unsafe fn deallocate(&self, ptr: std::ptr::NonNull<[u8]>, layout: std::alloc::Layout) {
        // SAFETY: every successful allocation above came from
        // `GlobalAllocator`, so returning `ptr`/`layout` to it is sound.
        unsafe { GlobalAllocator.deallocate(ptr, layout) }
    }
}
/// Serializing through an allocator that fails after its first allocation
/// must surface an `AllocatorError` (and must have actually allocated once).
fn test_plan_panic_boundary<const NBITS: usize>(v: &Impl<NBITS>)
where
    Impl<NBITS>: Quantizer,
{
    let flag = Arc::new(AtomicBool::new(false));
    let result = v.serialize(FlakyAllocator::new(flag.clone()));
    let _: AllocatorError = result.unwrap_err();
    assert!(flag.load(Ordering::Relaxed));
}
/// Expands to one `#[test]` per (bit-width, metric) pair; each test checks
/// the allocator-failure boundary and then the serialization round trip.
macro_rules! serialization_test {
    ($name:ident, $nbits:literal, $metric:ident) => {
        #[test]
        fn $name() {
            let (plan, data) = make_impl::<$nbits>(SupportedMetric::$metric);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, $nbits, data.as_view());
        }
    };
}
serialization_test!(test_plan_1bit_l2, 1, SquaredL2);
serialization_test!(test_plan_1bit_ip, 1, InnerProduct);
serialization_test!(test_plan_2bit_l2, 2, SquaredL2);
serialization_test!(test_plan_2bit_ip, 2, InnerProduct);
serialization_test!(test_plan_4bit_l2, 4, SquaredL2);
serialization_test!(test_plan_4bit_ip, 4, InnerProduct);
serialization_test!(test_plan_8bit_l2, 8, SquaredL2);
serialization_test!(test_plan_8bit_ip, 8, InnerProduct);
serialization_test!(test_plan_1bit_cosine, 1, Cosine);
serialization_test!(test_plan_2bit_cosine, 2, Cosine);
serialization_test!(test_plan_4bit_cosine, 4, Cosine);
serialization_test!(test_plan_8bit_cosine, 8, Cosine);
#[test]
/// Deserializes into a bump allocator and checks that the returned `Poly`'s
/// storage is the very first allocation made from it, per the assertion
/// message ("expected the returned box to be allocated first").
fn test_allocation_order() {
    let (plan, _) = make_impl::<1>(SupportedMetric::SquaredL2);
    let buf = plan.serialize(GlobalAllocator).unwrap();
    let allocator = BumpAllocator::new(8192, PowerOfTwo::new(64).unwrap()).unwrap();
    let deserialized =
        try_deserialize::<GlobalAllocator, _>(&buf, allocator.clone()).unwrap();
    // A bump allocator hands out addresses in order, so the first allocation
    // starts at the allocator's base pointer.
    assert_eq!(
        Poly::as_ptr(&deserialized).cast::<u8>(),
        allocator.as_ptr(),
        "expected the returned box to be allocated first",
    );
}
}
#[cfg(feature = "flatbuffers")]
mod compatibility {
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use super::*;
use crate::test_util::Check;
const TRAINING_SEED: u64 = 0x7d535118722ff197;
/// Serde-friendly mirror of [`SupportedMetric`] used in the baseline JSON
/// files (snake_case on the wire).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Metric {
    SquaredL2,
    InnerProduct,
    Cosine,
}
impl std::fmt::Display for Metric {
    // Display matches the serde snake_case names; used in baseline filenames.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            Self::SquaredL2 => "squared_l2",
            Self::InnerProduct => "inner_product",
            Self::Cosine => "cosine",
        };
        write!(f, "{}", s)
    }
}
impl From<SupportedMetric> for Metric {
    fn from(m: SupportedMetric) -> Self {
        match m {
            SupportedMetric::SquaredL2 => Self::SquaredL2,
            SupportedMetric::InnerProduct => Self::InnerProduct,
            SupportedMetric::Cosine => Self::Cosine,
        }
    }
}
impl From<Metric> for SupportedMetric {
    fn from(m: Metric) -> Self {
        match m {
            Metric::SquaredL2 => Self::SquaredL2,
            Metric::InnerProduct => Self::InnerProduct,
            Metric::Cosine => Self::Cosine,
        }
    }
}
/// Serde-friendly mirror of the training transform used in baseline files.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum DataTransform {
    PaddingHadamard,
    DoubleHadamard,
    Null,
}
impl std::fmt::Display for DataTransform {
    // Display matches the serde snake_case names; used in baseline filenames.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            Self::PaddingHadamard => "padding_hadamard",
            Self::DoubleHadamard => "double_hadamard",
            Self::Null => "null",
        };
        write!(f, "{}", s)
    }
}
impl DataTransform {
    /// Converts to the real [`TransformKind`] (always with a natural target
    /// dimension) used for training.
    fn to_transform_kind(self) -> TransformKind {
        match self {
            Self::PaddingHadamard => TransformKind::PaddingHadamard {
                target_dim: TargetDim::Natural,
            },
            Self::DoubleHadamard => TransformKind::DoubleHadamard {
                target_dim: TargetDim::Natural,
            },
            Self::Null => TransformKind::Null,
        }
    }
}
/// Serde-friendly mirror of [`QueryLayout`] used in baseline files.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Layout {
    SameAsData,
    FourBitTransposed,
    ScalarQuantized,
    FullPrecision,
}
impl From<QueryLayout> for Layout {
    fn from(l: QueryLayout) -> Self {
        match l {
            QueryLayout::SameAsData => Self::SameAsData,
            QueryLayout::FourBitTransposed => Self::FourBitTransposed,
            QueryLayout::ScalarQuantized => Self::ScalarQuantized,
            QueryLayout::FullPrecision => Self::FullPrecision,
        }
    }
}
impl From<Layout> for QueryLayout {
    fn from(l: Layout) -> Self {
        match l {
            Layout::SameAsData => Self::SameAsData,
            Layout::FourBitTransposed => Self::FourBitTransposed,
            Layout::ScalarQuantized => Self::ScalarQuantized,
            Layout::FullPrecision => Self::FullPrecision,
        }
    }
}
/// Serde-friendly mirror of [`PreScale`] used in baseline files.
#[derive(Debug, Default, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ScaleConfig {
    #[default]
    None,
    ReciprocalMeanNorm,
}
impl ScaleConfig {
    /// Filename suffix for this configuration; `None` contributes no suffix
    /// (see `baseline_path`).
    fn try_as_str(self) -> Option<&'static str> {
        match self {
            Self::None => Option::None,
            Self::ReciprocalMeanNorm => Some("rmn"),
        }
    }
    fn to_prescale(self) -> PreScale {
        match self {
            Self::None => PreScale::None,
            Self::ReciprocalMeanNorm => PreScale::ReciprocalMeanNorm,
        }
    }
}
/// One golden-file record: the full configuration, the serialized quantizer,
/// and every expected output checked by `check_baseline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Baseline {
    nbits: usize,
    metric: Metric,
    transform: DataTransform,
    pre_scale: ScaleConfig,
    dim: usize,
    full_dim: usize,
    // Seed used at generation time; checked against TRAINING_SEED on load.
    training_seed: u64,
    serialized_quantizer: Vec<u8>,
    // One compressed code per dataset row.
    compressed_vectors: Vec<Vec<u8>>,
    // Upper-triangular (i <= j) pairwise distances, row-major.
    data_distances: Vec<f32>,
    query_distances: Vec<LayoutDistances>,
    // Only present for the inner-product metric (allow_rescale = true).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    rescaled_query_distances: Option<Vec<LayoutDistances>>,
}
/// Query-to-data distances for one query layout: `distances[query][data]`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct LayoutDistances {
    layout: Layout,
    distances: Vec<Vec<f32>>,
}
/// Compresses every row of `dataset` with `quantizer`, returning one byte
/// vector per row.
fn compress_dataset(quantizer: &dyn Quantizer, dataset: MatrixView<f32>) -> Vec<Vec<u8>> {
    let scratch = ScopedAllocator::global();
    let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
    let mut codes = Vec::new();
    for row in dataset.row_iter() {
        let mut buf = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
        quantizer
            .compress(row, OpaqueMut::new(&mut buf), scratch)
            .unwrap();
        codes.push(buf.to_vec());
    }
    codes
}
/// Computes `distances[query][data]` for every query layout the quantizer
/// supports, using each dataset row in turn as the query vector.
fn compute_layout_distances(
    quantizer: &dyn Quantizer,
    dataset: MatrixView<f32>,
    compressed: &[Vec<u8>],
    allow_rescale: bool,
) -> Vec<LayoutDistances> {
    let scoped_global = ScopedAllocator::global();
    QueryLayout::all()
        .into_iter()
        // Unsupported layouts are simply absent from the output.
        .filter(|&layout| quantizer.is_supported(layout))
        .map(|layout| {
            let distances = dataset
                .row_iter()
                .map(|query_row| {
                    let computer = quantizer
                        .fused_query_computer(
                            query_row,
                            layout,
                            allow_rescale,
                            GlobalAllocator,
                            scoped_global,
                        )
                        .unwrap();
                    compressed
                        .iter()
                        .map(|d| computer.evaluate_similarity(Opaque::new(d)).unwrap())
                        .collect()
                })
                .collect();
            LayoutDistances {
                layout: layout.into(),
                distances,
            }
        })
        .collect()
}
// Absolute tolerance for comparing recomputed distances against baselines.
const TOLERANCE: Check = Check::absrel(1e-6, 0.0);
/// Recomputes per-layout query distances and asserts they match `expected`
/// within `TOLERANCE`, panicking with the failing (layout, query, data)
/// coordinates prefixed by `label`.
fn assert_layout_distances(
    quantizer: &dyn Quantizer,
    dataset: MatrixView<f32>,
    compressed: &[Vec<u8>],
    expected: &[LayoutDistances],
    allow_rescale: bool,
    label: &str,
) {
    let scoped_global = ScopedAllocator::global();
    for layout_distances in expected {
        let layout: QueryLayout = layout_distances.layout.into();
        assert!(
            quantizer.is_supported(layout),
            "{label}: unsupported layout {layout:?}"
        );
        for (qi, (query_row, expected_distances)) in
            std::iter::zip(dataset.row_iter(), layout_distances.distances.iter())
                .enumerate()
        {
            let computer = quantizer
                .fused_query_computer(
                    query_row,
                    layout,
                    allow_rescale,
                    GlobalAllocator,
                    scoped_global,
                )
                .unwrap();
            for (di, (compressed, expected)) in
                std::iter::zip(compressed.iter(), expected_distances.iter()).enumerate()
            {
                let distance = computer
                    .evaluate_similarity(Opaque::new(compressed))
                    .unwrap();
                if let Err(err) = TOLERANCE.check(distance, *expected) {
                    panic!("{label}: layout = {layout:?}, query={qi}, data={di}\n{err}")
                }
            }
        }
    }
}
/// Builds a complete [`Baseline`] record for `quantizer` over `dataset`:
/// compressed codes, pairwise data distances, per-layout query distances,
/// and the serialized quantizer bytes.
fn generate_baseline(
    quantizer: &dyn Quantizer,
    transform: DataTransform,
    pre_scale: ScaleConfig,
    dataset: MatrixView<f32>,
) -> Baseline {
    let compressed_vectors = compress_dataset(quantizer, dataset);
    let f = quantizer.distance_computer(GlobalAllocator).unwrap();
    // Upper-triangular pairwise distances (i <= j), flattened row-major;
    // `check_baseline` expects exactly n * (n + 1) / 2 entries.
    let mut data_distances = Vec::new();
    for (i, a) in compressed_vectors.iter().enumerate() {
        for b in compressed_vectors.iter().skip(i) {
            data_distances.push(
                f.evaluate_similarity(Opaque::new(a), Opaque::new(b))
                    .unwrap(),
            );
        }
    }
    let query_distances =
        compute_layout_distances(quantizer, dataset, &compressed_vectors, false);
    let is_ip = quantizer.metric() == SupportedMetric::InnerProduct;
    let rescaled = compute_layout_distances(quantizer, dataset, &compressed_vectors, true);
    // Rescaling is only meaningful for inner product; for any other metric
    // the rescaled distances must be identical and are not stored.
    if !is_ip {
        assert_eq!(
            query_distances,
            rescaled,
            "allow_rescale should not affect {:?} distances",
            quantizer.metric(),
        );
    }
    let rescaled_query_distances = if is_ip { Some(rescaled) } else { None };
    let serialized = quantizer.serialize(GlobalAllocator).unwrap();
    Baseline {
        nbits: quantizer.nbits(),
        metric: quantizer.metric().into(),
        transform,
        pre_scale,
        dim: quantizer.dim(),
        full_dim: quantizer.full_dim(),
        training_seed: TRAINING_SEED,
        serialized_quantizer: serialized.to_vec(),
        compressed_vectors,
        data_distances,
        query_distances,
        rescaled_query_distances,
    }
}
/// Builds the golden-file path for one configuration:
/// `<manifest>/test/generated/spherical/<nbits>bit_<metric>_<transform>[_<scale>].json`.
fn baseline_path(
    nbits: usize,
    metric: Metric,
    transform: DataTransform,
    pre_scale: ScaleConfig,
) -> PathBuf {
    let mut name = format!("{}bit_{}_{}", nbits, metric, transform);
    // `ScaleConfig::None` contributes no suffix.
    if let Some(suffix) = pre_scale.try_as_str() {
        name = format!("{}_{}", name, suffix);
    }
    name.push_str(".json");
    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("test")
        .join("generated")
        .join("spherical")
        .join(name)
}
/// True when `DISKANN_TEST=overwrite` is set, switching the compatibility
/// tests from verifying their baselines to regenerating them.
fn should_overwrite() -> bool {
    matches!(std::env::var("DISKANN_TEST"), Ok(flag) if flag == "overwrite")
}
/// Writes `baseline` as pretty-printed JSON to its canonical path, creating
/// parent directories as needed. Panics on any I/O or serialization failure
/// (this only runs under `DISKANN_TEST=overwrite`).
fn save_baseline(baseline: &Baseline) {
    let path = baseline_path(
        baseline.nbits,
        baseline.metric,
        baseline.transform,
        baseline.pre_scale,
    );
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).unwrap();
    }
    let json = serde_json::to_string_pretty(baseline).unwrap();
    std::fs::write(&path, json).unwrap();
}
/// Loads the golden baseline for one configuration, panicking with
/// regeneration instructions when the file is missing or unreadable.
fn load_baseline(
    nbits: usize,
    metric: Metric,
    transform: DataTransform,
    pre_scale: ScaleConfig,
) -> Baseline {
    let path = baseline_path(nbits, metric, transform, pre_scale);
    let json = std::fs::read_to_string(&path).unwrap_or_else(|e| {
        panic!(
            "Failed to load baseline from {}: {e}\n\
            Run with DISKANN_TEST=overwrite to generate baseline files.",
            path.display()
        );
    });
    serde_json::from_str(&json).unwrap()
}
/// Verifies a loaded [`Baseline`] against the current implementation:
/// deserializes the stored quantizer, re-compresses the dataset, and checks
/// metadata, pairwise data distances, and per-layout query distances
/// (including the rescale behavior) within `TOLERANCE`.
fn check_baseline(
    baseline: &Baseline,
    dataset: MatrixView<f32>,
    expected_transform: DataTransform,
    expected_pre_scale: ScaleConfig,
) {
    assert_eq!(
        baseline.training_seed, TRAINING_SEED,
        "baseline was generated with a different training seed"
    );
    assert_eq!(
        baseline.transform, expected_transform,
        "baseline transform does not match expected transform"
    );
    assert_eq!(
        baseline.pre_scale, expected_pre_scale,
        "baseline pre_scale does not match expected pre_scale"
    );
    let quantizer = try_deserialize::<GlobalAllocator, _>(
        &baseline.serialized_quantizer,
        GlobalAllocator,
    )
    .unwrap();
    assert_eq!(quantizer.nbits(), baseline.nbits, "nbits mismatch");
    assert_eq!(quantizer.dim(), baseline.dim, "dim mismatch");
    assert_eq!(quantizer.full_dim(), baseline.full_dim, "full_dim mismatch");
    assert_eq!(
        quantizer.metric(),
        SupportedMetric::from(baseline.metric),
        "metric mismatch"
    );
    // Codes must be bit-identical to what was recorded.
    let compressed = compress_dataset(&*quantizer, dataset);
    assert_eq!(
        compressed, baseline.compressed_vectors,
        "compressed vectors do not match baseline"
    );
    let f = quantizer.distance_computer(GlobalAllocator).unwrap();
    // `generate_baseline` stored the upper triangle (i <= j), so there are
    // n * (n + 1) / 2 recorded distances.
    let n = baseline.compressed_vectors.len();
    let expected_len = n * (n + 1) / 2;
    assert_eq!(
        baseline.data_distances.len(),
        expected_len,
        "baseline.data_distances has length {}, expected {} for {} compressed vectors",
        baseline.data_distances.len(),
        expected_len,
        n,
    );
    // Walk the pairs in the same (i, j >= i) order they were generated.
    let mut k = 0;
    for (i, a) in baseline.compressed_vectors.iter().enumerate() {
        for (j, b) in baseline.compressed_vectors.iter().enumerate().skip(i) {
            let distance = f
                .evaluate_similarity(Opaque::new(a), Opaque::new(b))
                .unwrap();
            if let Err(err) = TOLERANCE.check(distance, baseline.data_distances[k]) {
                panic!("data distance mismatch at pair ({i}, {j})\n{err}");
            }
            k += 1;
        }
    }
    assert_layout_distances(
        &*quantizer,
        dataset,
        &baseline.compressed_vectors,
        &baseline.query_distances,
        false,
        "query distance mismatch",
    );
    if let Some(rescaled) = &baseline.rescaled_query_distances {
        // Rescaled distances are only recorded for inner product.
        assert_eq!(
            SupportedMetric::from(baseline.metric),
            SupportedMetric::InnerProduct,
            "rescaled_query_distances should only be present \
            for InnerProduct"
        );
        assert_layout_distances(
            &*quantizer,
            dataset,
            &baseline.compressed_vectors,
            rescaled,
            true,
            "rescaled query distance mismatch",
        );
    } else {
        // For non-IP metrics, allow_rescale must be a no-op: the unscaled
        // expectations must also hold with allow_rescale = true.
        assert_layout_distances(
            &*quantizer,
            dataset,
            &baseline.compressed_vectors,
            &baseline.query_distances,
            true,
            "allow_rescale should not affect non-IP distances",
        );
    }
}
/// Trains a `SphericalQuantizer` on the shared test dataset with a fixed RNG
/// seed and wraps it in `Impl<NBITS>`. Returns the wrapper together with the
/// dataset it was trained on so callers can reuse it for distance checks.
fn make_impl_with<const NBITS: usize>(
    metric: SupportedMetric,
    transform: DataTransform,
    pre_scale: ScaleConfig,
) -> (Impl<NBITS>, Matrix<f32>)
where
    Impl<NBITS>: Constructible,
{
    let dataset = test_dataset();
    // Fixed seed keeps training deterministic so stored baselines stay valid.
    let mut training_rng = StdRng::seed_from_u64(TRAINING_SEED);
    let trained = SphericalQuantizer::train(
        dataset.as_view(),
        transform.to_transform_kind(),
        metric,
        pre_scale.to_prescale(),
        &mut training_rng,
        GlobalAllocator,
    )
    .unwrap();
    let wrapped = Impl::<NBITS>::new(trained).unwrap();
    (wrapped, dataset)
}
/// End-to-end compatibility check for one quantizer configuration: builds the
/// quantizer, then either regenerates and saves the stored baseline (when
/// overwriting is requested via `should_overwrite`) or loads the existing one,
/// and finally verifies the current implementation against that baseline.
fn run_compatibility_test<const NBITS: usize>(
    metric: SupportedMetric,
    transform: DataTransform,
    pre_scale: ScaleConfig,
) where
    Impl<NBITS>: Constructible + Quantizer,
{
    let (quantizer, data) = make_impl_with::<NBITS>(metric, transform, pre_scale);
    let dataset = data.as_view();
    // Deferred initialization: exactly one branch assigns `baseline`.
    let baseline;
    if should_overwrite() {
        baseline = generate_baseline(&quantizer, transform, pre_scale, dataset);
        save_baseline(&baseline);
    } else {
        baseline = load_baseline(NBITS, metric.into(), transform, pre_scale);
    }
    check_baseline(&baseline, dataset, transform, pre_scale);
}
/// Baseline compatibility: 1-bit, squared L2, double-Hadamard, no pre-scale.
#[test]
fn compat_1bit_l2() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<1>(SupportedMetric::SquaredL2, transform, ScaleConfig::None);
}
/// Baseline compatibility: 1-bit, inner product, double-Hadamard, no pre-scale.
#[test]
fn compat_1bit_ip() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<1>(SupportedMetric::InnerProduct, transform, ScaleConfig::None);
}
/// Baseline compatibility: 1-bit, cosine, double-Hadamard, no pre-scale.
#[test]
fn compat_1bit_cosine() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<1>(SupportedMetric::Cosine, transform, ScaleConfig::None);
}
/// Baseline compatibility: 2-bit, squared L2, double-Hadamard, no pre-scale.
#[test]
fn compat_2bit_l2() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<2>(SupportedMetric::SquaredL2, transform, ScaleConfig::None);
}
/// Baseline compatibility: 2-bit, inner product, double-Hadamard, no pre-scale.
#[test]
fn compat_2bit_ip() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<2>(SupportedMetric::InnerProduct, transform, ScaleConfig::None);
}
/// Baseline compatibility: 2-bit, cosine, double-Hadamard, no pre-scale.
#[test]
fn compat_2bit_cosine() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<2>(SupportedMetric::Cosine, transform, ScaleConfig::None);
}
/// Baseline compatibility: 4-bit, squared L2, double-Hadamard, no pre-scale.
#[test]
fn compat_4bit_l2() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<4>(SupportedMetric::SquaredL2, transform, ScaleConfig::None);
}
/// Baseline compatibility: 4-bit, inner product, double-Hadamard, no pre-scale.
#[test]
fn compat_4bit_ip() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<4>(SupportedMetric::InnerProduct, transform, ScaleConfig::None);
}
/// Baseline compatibility: 4-bit, cosine, double-Hadamard, no pre-scale.
#[test]
fn compat_4bit_cosine() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<4>(SupportedMetric::Cosine, transform, ScaleConfig::None);
}
/// Baseline compatibility: 8-bit, squared L2, double-Hadamard, no pre-scale.
#[test]
fn compat_8bit_l2() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<8>(SupportedMetric::SquaredL2, transform, ScaleConfig::None);
}
/// Baseline compatibility: 8-bit, inner product, double-Hadamard, no pre-scale.
#[test]
fn compat_8bit_ip() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<8>(SupportedMetric::InnerProduct, transform, ScaleConfig::None);
}
/// Baseline compatibility: 8-bit, cosine, double-Hadamard, no pre-scale.
#[test]
fn compat_8bit_cosine() {
    let transform = DataTransform::DoubleHadamard;
    run_compatibility_test::<8>(SupportedMetric::Cosine, transform, ScaleConfig::None);
}
/// Baseline compatibility: 4-bit, squared L2, with the null (identity) transform.
#[test]
fn compat_4bit_l2_null() {
    let transform = DataTransform::Null;
    run_compatibility_test::<4>(SupportedMetric::SquaredL2, transform, ScaleConfig::None);
}
/// Baseline compatibility: 4-bit, squared L2, with the padding-Hadamard transform.
#[test]
fn compat_4bit_l2_padding_hadamard() {
    let transform = DataTransform::PaddingHadamard;
    run_compatibility_test::<4>(SupportedMetric::SquaredL2, transform, ScaleConfig::None);
}
/// Baseline compatibility: 4-bit, squared L2, double-Hadamard, with
/// reciprocal-mean-norm pre-scaling enabled.
#[test]
fn compat_4bit_l2_prescale_rmn() {
    let pre_scale = ScaleConfig::ReciprocalMeanNorm;
    run_compatibility_test::<4>(SupportedMetric::SquaredL2, DataTransform::DoubleHadamard, pre_scale);
}
}
}