use diskann_utils::{Reborrow, ReborrowMut};
use diskann_vector::{Norm, norm::FastL2NormSquared};
use diskann_wide::{Architecture, arch::Target2};
use half::f16;
use thiserror::Error;
#[cfg(feature = "flatbuffers")]
use crate::flatbuffers as fb;
use crate::{
alloc::{AllocatorCore, AllocatorError, Poly},
bits::{BitSlice, Dense, PermutationStrategy, Representation, Unsigned},
distances::{self, InnerProduct, MV},
meta,
};
/// The subset of [`diskann_vector::distance::Metric`] values that spherical
/// quantization knows how to compensate for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupportedMetric {
    /// Squared Euclidean distance.
    SquaredL2,
    /// Inner-product similarity (negated at the final distance step).
    InnerProduct,
    /// Cosine distance; shares the inner-product machinery.
    Cosine,
}
#[cfg(test)]
impl SupportedMetric {
    /// Test-only helper: select the metric-specific correction stored in a
    /// vector's metadata — the squared shifted norm for L2, or the inner
    /// product with the centroid for IP/cosine.
    fn pick(self, shifted_norm: f32, inner_product_with_centroid: f32) -> f32 {
        match self {
            Self::SquaredL2 => shifted_norm * shifted_norm,
            Self::InnerProduct | Self::Cosine => inner_product_with_centroid,
        }
    }
    /// All supported metrics, for exhaustive round-trip tests of the
    /// flatbuffers encoding. Only exists in test builds with the
    /// `flatbuffers` feature enabled.
    #[cfg(feature = "flatbuffers")]
    pub(super) fn all() -> [Self; 3] {
        [Self::SquaredL2, Self::InnerProduct, Self::Cosine]
    }
}
impl TryFrom<diskann_vector::distance::Metric> for SupportedMetric {
    type Error = UnsupportedMetric;

    /// Map a general metric onto the subset supported by spherical
    /// quantization, rejecting everything else.
    ///
    /// # Errors
    /// Returns [`UnsupportedMetric`] wrapping the offending metric.
    fn try_from(metric: diskann_vector::distance::Metric) -> Result<Self, Self::Error> {
        use diskann_vector::distance::Metric;
        Ok(match metric {
            Metric::L2 => Self::SquaredL2,
            Metric::InnerProduct => Self::InnerProduct,
            Metric::Cosine => Self::Cosine,
            other => return Err(UnsupportedMetric(other)),
        })
    }
}
impl PartialEq<diskann_vector::distance::Metric> for SupportedMetric {
    /// A `SupportedMetric` equals a general `Metric` exactly when the
    /// metric converts to the same supported variant; unsupported metrics
    /// compare unequal to every variant.
    fn eq(&self, metric: &diskann_vector::distance::Metric) -> bool {
        Self::try_from(*metric).map_or(false, |converted| converted == *self)
    }
}
/// Error returned when converting a [`diskann_vector::distance::Metric`]
/// that spherical quantization cannot handle.
#[derive(Debug, Clone, Copy, Error)]
#[error("metric {0:?} is not supported for spherical quantization")]
pub struct UnsupportedMetric(pub(crate) diskann_vector::distance::Metric);
/// Error returned when a serialized flatbuffers metric discriminant does
/// not match any known variant (e.g. data written by a newer version).
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Copy, PartialEq, Error)]
#[error("the value {0} is not recognized as a supported metric")]
pub struct InvalidMetric(i8);
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl TryFrom<fb::spherical::SupportedMetric> for SupportedMetric {
    type Error = InvalidMetric;

    /// Decode a flatbuffers metric tag.
    ///
    /// # Errors
    /// Returns [`InvalidMetric`] carrying the raw discriminant when the tag
    /// is not one of the known variants.
    fn try_from(value: fb::spherical::SupportedMetric) -> Result<Self, Self::Error> {
        use fb::spherical::SupportedMetric as Tag;
        Ok(match value {
            Tag::SquaredL2 => Self::SquaredL2,
            Tag::InnerProduct => Self::InnerProduct,
            Tag::Cosine => Self::Cosine,
            other => return Err(InvalidMetric(other.0)),
        })
    }
}
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
impl From<SupportedMetric> for fb::spherical::SupportedMetric {
    /// Encode a supported metric as its flatbuffers tag. This conversion is
    /// total, so no fallible counterpart is needed.
    fn from(value: SupportedMetric) -> Self {
        match value {
            SupportedMetric::SquaredL2 => Self::SquaredL2,
            SupportedMetric::InnerProduct => Self::InnerProduct,
            SupportedMetric::Cosine => Self::Cosine,
        }
    }
}
/// Per-vector metadata stored alongside each quantized data vector, packed
/// into six bytes: two `f16` corrections plus a `u16` code sum.
#[derive(Debug, Default, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct DataMeta {
    /// Multiplicative scale applied to the raw integer inner product.
    pub inner_product_correction: f16,
    /// Metric-dependent additive term (see [`SupportedMetric`]).
    pub metric_specific: f16,
    /// Sum of all quantized codes of the vector.
    pub bit_sum: u16,
}
/// Failures constructing [`DataMeta`] when a value does not fit its
/// narrowed storage type.
#[derive(Debug, Error, Clone, Copy, PartialEq)]
pub enum DataMetaError {
    #[error("inner product correction {value} cannot fit in a 16-bit floating point number")]
    InnerProductCorrection { value: f32 },
    #[error("metric specific correction {value} cannot fit in a 16-bit floating point number")]
    MetricSpecific { value: f32 },
    #[error("bit sum {value} cannot fit in a 16-bit unsigned integer")]
    BitSum { value: u32 },
}
impl DataMeta {
    /// Validate and pack full-precision corrections into the 16-bit fields.
    ///
    /// # Errors
    /// * [`DataMetaError::InnerProductCorrection`] /
    ///   [`DataMetaError::MetricSpecific`] when the value is not finite
    ///   after narrowing to `f16` (overflow to infinity, or NaN input).
    /// * [`DataMetaError::BitSum`] when `bit_sum` exceeds `u16::MAX`.
    pub fn new(
        inner_product_correction: f32,
        metric_specific: f32,
        bit_sum: u32,
    ) -> Result<Self, DataMetaError> {
        // Out-of-range `f32` values become +/- infinity in `f16`; the
        // finiteness check rejects those (and NaNs), reporting the original
        // `f32` value in the error.
        let inner_product_correction_f16 = diskann_wide::cast_f32_to_f16(inner_product_correction);
        if !inner_product_correction_f16.is_finite() {
            return Err(DataMetaError::InnerProductCorrection {
                value: inner_product_correction,
            });
        }
        let metric_specific_f16 = diskann_wide::cast_f32_to_f16(metric_specific);
        if !metric_specific_f16.is_finite() {
            return Err(DataMetaError::MetricSpecific {
                value: metric_specific,
            });
        }
        let bit_sum_u16: u16 = bit_sum
            .try_into()
            .map_err(|_| DataMetaError::BitSum { value: bit_sum })?;
        Ok(Self {
            inner_product_correction: inner_product_correction_f16,
            metric_specific: metric_specific_f16,
            bit_sum: bit_sum_u16,
        })
    }
    /// Midpoint of the unsigned code domain `[0, 2^NBITS - 1]`, used by the
    /// distance kernels to recentre codes around zero.
    const fn offset_term<const NBITS: usize>() -> f32 {
        ((2usize).pow(NBITS as u32) as f32 - 1.0) / 2.0
    }
    /// Widen the packed metadata back to `f32` for distance arithmetic.
    ///
    /// Both `f16` fields are converted with a single 8-lane SIMD
    /// conversion; lanes 2..8 are default-valued padding and are discarded.
    #[inline(always)]
    pub fn to_full<A>(self, arch: A) -> DataMetaF32
    where
        A: Architecture,
    {
        use diskann_wide::SIMDVector;
        let pre = [
            self.metric_specific,
            self.inner_product_correction,
            half::f16::default(),
            half::f16::default(),
            half::f16::default(),
            half::f16::default(),
            half::f16::default(),
            half::f16::default(),
        ];
        let post: <A as Architecture>::f32x8 =
            <A as Architecture>::f16x8::from_array(arch, pre).into();
        let post = post.to_array();
        // Lane order matches `pre`: [metric_specific, inner_product_correction, ...].
        DataMetaF32 {
            metric_specific: post[0],
            inner_product_correction: post[1],
            bit_sum: self.bit_sum.into(),
        }
    }
}
/// Full-precision (`f32`) view of [`DataMeta`], produced by
/// [`DataMeta::to_full`] for use inside the distance kernels.
#[derive(Debug, Default, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct DataMetaF32 {
    pub inner_product_correction: f32,
    pub metric_specific: f32,
    /// `bit_sum` widened to `f32` so it can join float arithmetic directly.
    pub bit_sum: f32,
}
/// Borrowed quantized data vector plus its [`DataMeta`].
pub type DataRef<'a, const NBITS: usize> = meta::VectorRef<'a, NBITS, Unsigned, DataMeta>;
/// Mutable counterpart of [`DataRef`].
pub type DataMut<'a, const NBITS: usize> = meta::VectorMut<'a, NBITS, Unsigned, DataMeta>;
/// Owning, densely packed quantized data vector with metadata.
pub type Data<const NBITS: usize, A> = meta::PolyVector<NBITS, Unsigned, DataMeta, Dense, A>;
/// Metadata carried by a quantized query vector. Unlike [`DataMeta`] the
/// fields are kept at full `f32` precision, and the affine quantization
/// `offset` of the query is recorded explicitly.
#[derive(Copy, Clone, Default, Debug, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct QueryMeta {
    pub inner_product_correction: f32,
    pub bit_sum: f32,
    /// Query quantization offset; enters the asymmetric correction terms
    /// in the query-vs-data kernels below.
    pub offset: f32,
    pub metric_specific: f32,
}
/// Owning quantized query with permutation strategy `Perm`.
pub type Query<const NBITS: usize, Perm, A> = meta::PolyVector<NBITS, Unsigned, QueryMeta, Perm, A>;
/// Borrowed quantized query vector plus its [`QueryMeta`].
pub type QueryRef<'a, const NBITS: usize, Perm> =
    meta::VectorRef<'a, NBITS, Unsigned, QueryMeta, Perm>;
/// Mutable counterpart of [`QueryRef`].
pub type QueryMut<'a, const NBITS: usize, Perm> =
    meta::VectorMut<'a, NBITS, Unsigned, QueryMeta, Perm>;
/// Metadata for an uncompressed (full `f32`) query.
#[derive(Debug, Clone, Copy, Default, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct FullQueryMeta {
    /// Sum of the query components; used to recentre data codes in the
    /// asymmetric kernels.
    pub sum: f32,
    /// Norm of the query after shifting by the centroid.
    pub shifted_norm: f32,
    pub metric_specific: f32,
}
/// Owning uncompressed query: raw `f32` components plus metadata.
#[derive(Debug)]
pub struct FullQuery<A>
where
    A: AllocatorCore,
{
    pub data: Poly<[f32], A>,
    pub meta: FullQueryMeta,
}
impl<A> FullQuery<A>
where
A: AllocatorCore,
{
pub fn empty(dim: usize, allocator: A) -> Result<Self, AllocatorError> {
Ok(Self {
data: Poly::broadcast(0.0f32, dim, allocator)?,
meta: Default::default(),
})
}
}
/// Borrowed view of a [`FullQuery`].
pub type FullQueryRef<'a> = meta::slice::SliceRef<'a, f32, FullQueryMeta>;
/// Mutable borrowed view of a [`FullQuery`].
pub type FullQueryMut<'a> = meta::slice::SliceMut<'a, f32, FullQueryMeta>;
impl<'short, A> Reborrow<'short> for FullQuery<A>
where
    A: AllocatorCore,
{
    type Target = FullQueryRef<'short>;
    // Cheap shared view over the owned components and metadata.
    fn reborrow(&'short self) -> Self::Target {
        FullQueryRef::new(&self.data, &self.meta)
    }
}
impl<'short, A> ReborrowMut<'short> for FullQuery<A>
where
    A: AllocatorCore,
{
    type Target = FullQueryMut<'short>;
    // Cheap exclusive view over the owned components and metadata.
    fn reborrow_mut(&'short mut self) -> Self::Target {
        FullQueryMut::new(&mut self.data, &mut self.meta)
    }
}
/// Zero-sized carrier that forces the per-`NBITS` offset terms to be
/// evaluated at compile time as associated constants.
struct ConstOffset<const NBITS: usize>;
impl<const NBITS: usize> ConstOffset<NBITS> {
    /// Midpoint of the unsigned code domain, `(2^NBITS - 1) / 2`.
    const OFFSET: f32 = DataMeta::offset_term::<NBITS>();
    /// `OFFSET` squared, pre-computed for the symmetric kernel.
    const OFFSET_SQUARED: f32 = DataMeta::offset_term::<NBITS>() * DataMeta::offset_term::<NBITS>();
}
/// Shared symmetric (data-vs-data) kernel: the corrected inner product of
/// two quantized vectors.
///
/// The raw integer inner product `ip` is recentred as if the code midpoint
/// `c` had been subtracted from every component of both operands:
/// `sum (x_i - c)(y_i - c) = ip - c * (bit_sum_x + bit_sum_y) + c^2 * dim`,
/// then scaled by both vectors' inner-product corrections.
#[inline(always)]
fn kernel<A, const NBITS: usize>(
    arch: A,
    x: DataRef<'_, NBITS>,
    y: DataRef<'_, NBITS>,
    dim: f32,
) -> distances::Result<f32>
where
    A: Architecture,
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a> Target2<
        A,
        distances::MathematicalResult<u32>,
        BitSlice<'a, NBITS, Unsigned>,
        BitSlice<'a, NBITS, Unsigned>,
    >,
{
    // Exact integer inner product of the packed codes.
    let ip: distances::MathematicalResult<u32> =
        <_ as Target2<_, _, _, _>>::run(InnerProduct, arch, x.vector(), y.vector());
    let ip = ip?.into_inner() as f32;
    let offset = ConstOffset::<NBITS>::OFFSET;
    let offset_squared = ConstOffset::<NBITS>::OFFSET_SQUARED;
    let xc = x.meta().to_full(arch);
    let yc = y.meta().to_full(arch);
    Ok(xc.inner_product_correction
        * yc.inner_product_correction
        * (ip - offset * (xc.bit_sum + yc.bit_sum) + offset_squared * dim))
}
/// Compensated squared-L2 distance functor over spherically quantized
/// vectors.
#[derive(Debug, Clone, Copy)]
pub struct CompensatedSquaredL2 {
    /// Dimensionality of the uncompressed vectors, pre-cast to `f32`.
    pub(super) dim: f32,
}
impl CompensatedSquaredL2 {
    /// Create the functor for vectors of dimensionality `dim`.
    pub fn new(dim: usize) -> Self {
        Self { dim: dim as f32 }
    }
}
/// Adapter: for squared L2 the final distance is simply the mathematical
/// result unwrapped — no sign or affine transform is applied.
impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedSquaredL2
where
    A: Architecture,
    Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
{
    #[inline(always)]
    fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
        // Resolves to the `MathematicalResult` impl via the `Self:` bound.
        self.run(arch, x, y).map(|r| r.into_inner())
    }
}
/// Symmetric squared L2 between two quantized data vectors:
/// `||x||^2 + ||y||^2 - 2 * <x, y>`. For the L2 metric each vector's
/// `metric_specific` field carries its squared shifted norm, and the cross
/// term comes from the shared `kernel`.
impl<A, const NBITS: usize>
    Target2<A, distances::MathematicalResult<f32>, DataRef<'_, NBITS>, DataRef<'_, NBITS>>
    for CompensatedSquaredL2
where
    A: Architecture,
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a> Target2<
        A,
        distances::MathematicalResult<u32>,
        BitSlice<'a, NBITS, Unsigned>,
        BitSlice<'a, NBITS, Unsigned>,
    >,
{
    #[inline(always)]
    fn run(
        self,
        arch: A,
        x: DataRef<'_, NBITS>,
        y: DataRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        let xc = x.meta().to_full(arch);
        let yc = y.meta().to_full(arch);
        let result = xc.metric_specific + yc.metric_specific - 2.0 * kernel(arch, x, y, self.dim)?;
        Ok(MV::new(result))
    }
}
/// Asymmetric squared L2 between a quantized query (Q bits, permutation
/// `Perm`) and a quantized data vector (D bits).
///
/// The query is quantized affinely (`value = base + scale * code`), so its
/// recentring uses the stored `offset` (= base / scale) while the data side
/// is recentred by the code midpoint `offset_term::<D>()`:
/// `sum (q_i + off_x)(d_i - off_y)
///   = ip - off_y * bitsum_x + off_x * bitsum_y - off_y * off_x * dim`.
impl<A, const Q: usize, const D: usize, Perm>
    Target2<A, distances::MathematicalResult<f32>, QueryRef<'_, Q, Perm>, DataRef<'_, D>>
    for CompensatedSquaredL2
where
    A: Architecture,
    Unsigned: Representation<Q>,
    Unsigned: Representation<D>,
    Perm: PermutationStrategy<Q>,
    for<'a> InnerProduct: Target2<
        A,
        distances::MathematicalResult<u32>,
        BitSlice<'a, Q, Unsigned, Perm>,
        BitSlice<'a, D, Unsigned>,
    >,
{
    #[inline(always)]
    fn run(
        self,
        arch: A,
        x: QueryRef<'_, Q, Perm>,
        y: DataRef<'_, D>,
    ) -> distances::MathematicalResult<f32> {
        let ip: distances::MathematicalResult<u32> =
            arch.run2_inline(InnerProduct, x.vector(), y.vector());
        let ip = ip?.into_inner() as f32;
        let yc = y.meta().to_full(arch);
        let xc = x.meta();
        let y_offset: f32 = DataMeta::offset_term::<D>();
        let corrected_ip = yc.inner_product_correction
            * xc.inner_product_correction
            * (ip - y_offset * xc.bit_sum + xc.offset * yc.bit_sum
                - y_offset * xc.offset * self.dim);
        Ok(MV::new(
            yc.metric_specific + xc.metric_specific - 2.0 * corrected_ip,
        ))
    }
}
/// Squared L2 between a full-precision (`f32`) query and a quantized data
/// vector. Only the data side needs recentring: `ip = s - sum_x * offset`
/// removes the code-midpoint contribution, and the cross term is scaled by
/// the query's shifted norm and the data vector's correction.
impl<A, const NBITS: usize>
    Target2<A, distances::MathematicalResult<f32>, FullQueryRef<'_>, DataRef<'_, NBITS>>
    for CompensatedSquaredL2
where
    A: Architecture,
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a> Target2<
        A,
        distances::MathematicalResult<f32>,
        &'a [f32],
        BitSlice<'a, NBITS, Unsigned>,
    >,
{
    #[inline(always)]
    fn run(
        self,
        arch: A,
        x: FullQueryRef<'_>,
        y: DataRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        let s = arch
            .run2(InnerProduct, x.vector(), y.vector())?
            .into_inner();
        let xc = x.meta();
        let yc = y.meta().to_full(arch);
        let offset = ConstOffset::<NBITS>::OFFSET;
        let ip = s - xc.sum * offset;
        let r = xc.metric_specific + yc.metric_specific
            - 2.0 * xc.shifted_norm * yc.inner_product_correction * ip;
        Ok(MV::new(r))
    }
}
/// Compensated inner-product distance functor; also the computational core
/// of [`CompensatedCosine`].
#[derive(Debug, Clone, Copy)]
pub struct CompensatedIP {
    /// Squared norm of the centering shift, added back into every result.
    pub(super) squared_shift_norm: f32,
    /// Dimensionality of the uncompressed vectors, pre-cast to `f32`.
    pub(super) dim: f32,
}
impl CompensatedIP {
    /// Build the functor from the centering `shift` vector and the vector
    /// dimensionality. `||shift||^2` is cached so that each distance
    /// evaluation only pays a constant-time addition.
    pub fn new(shift: &[f32], dim: usize) -> Self {
        let squared_shift_norm = FastL2NormSquared.evaluate(shift);
        let dim = dim as f32;
        Self {
            squared_shift_norm,
            dim,
        }
    }
}
/// Adapter: the mathematical similarity is negated so that larger inner
/// products sort as smaller distances.
impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedIP
where
    A: Architecture,
    Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
{
    #[inline(always)]
    fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
        arch.run2(self, x, y).map(|r| -r.into_inner())
    }
}
/// Symmetric inner product between two quantized data vectors. For IP and
/// cosine the `metric_specific` field carries each vector's inner product
/// with the centroid; adding both plus `||shift||^2` undoes the centering:
/// `<x, y> = <x - c, y - c> + <x, c> + <y, c> - ||c||^2` (rearranged).
impl<A, const NBITS: usize>
    Target2<A, distances::MathematicalResult<f32>, DataRef<'_, NBITS>, DataRef<'_, NBITS>>
    for CompensatedIP
where
    A: Architecture,
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a> Target2<
        A,
        distances::MathematicalResult<u32>,
        BitSlice<'a, NBITS, Unsigned>,
        BitSlice<'a, NBITS, Unsigned>,
    >,
{
    #[inline(always)]
    fn run(
        self,
        arch: A,
        x: DataRef<'_, NBITS>,
        y: DataRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        let xc = x.meta().to_full(arch);
        let yc = y.meta().to_full(arch);
        let result = xc.metric_specific
            + yc.metric_specific
            + kernel(arch, x, y, self.dim)?
            + self.squared_shift_norm;
        Ok(MV::new(result))
    }
}
/// Asymmetric inner product between a quantized query and a quantized data
/// vector. Uses the same recentred cross term as the asymmetric L2 impl,
/// then adds both metric corrections and `||shift||^2`.
impl<A, const Q: usize, const D: usize, Perm>
    Target2<A, distances::MathematicalResult<f32>, QueryRef<'_, Q, Perm>, DataRef<'_, D>>
    for CompensatedIP
where
    A: Architecture,
    Unsigned: Representation<Q>,
    Unsigned: Representation<D>,
    Perm: PermutationStrategy<Q>,
    for<'a> InnerProduct: Target2<
        A,
        distances::MathematicalResult<u32>,
        BitSlice<'a, Q, Unsigned, Perm>,
        BitSlice<'a, D, Unsigned>,
    >,
{
    #[inline(always)]
    fn run(
        self,
        arch: A,
        x: QueryRef<'_, Q, Perm>,
        y: DataRef<'_, D>,
    ) -> distances::MathematicalResult<f32> {
        let ip: MV<u32> = arch.run2_inline(InnerProduct, x.vector(), y.vector())?;
        let yc = y.meta().to_full(arch);
        let xc = x.meta();
        let y_offset: f32 = DataMeta::offset_term::<D>();
        let corrected_ip = xc.inner_product_correction
            * yc.inner_product_correction
            * (ip.into_inner() as f32 - y_offset * xc.bit_sum + xc.offset * yc.bit_sum
                - y_offset * xc.offset * self.dim);
        Ok(MV::new(
            corrected_ip + yc.metric_specific + xc.metric_specific + self.squared_shift_norm,
        ))
    }
}
/// Inner product between a full-precision query and a quantized data
/// vector: recentre the data side by the code midpoint, scale by the
/// query's shifted norm and the data correction, then add the metric
/// corrections and `||shift||^2`.
impl<A, const NBITS: usize>
    Target2<A, distances::MathematicalResult<f32>, FullQueryRef<'_>, DataRef<'_, NBITS>>
    for CompensatedIP
where
    A: Architecture,
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a> Target2<
        A,
        distances::MathematicalResult<f32>,
        &'a [f32],
        BitSlice<'a, NBITS, Unsigned>,
    >,
{
    #[inline(always)]
    fn run(
        self,
        arch: A,
        x: FullQueryRef<'_>,
        y: DataRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        let s = arch
            .run2(InnerProduct, x.vector(), y.vector())?
            .into_inner();
        let yc = y.meta().to_full(arch);
        let xc = x.meta();
        let offset = ConstOffset::<NBITS>::OFFSET;
        let ip = xc.shifted_norm * yc.inner_product_correction * (s - xc.sum * offset);
        Ok(MV::new(
            ip + xc.metric_specific + yc.metric_specific + self.squared_shift_norm,
        ))
    }
}
/// Compensated cosine distance: reuses the inner-product machinery and
/// applies the final `1 - similarity` transform.
#[derive(Debug, Clone, Copy)]
pub struct CompensatedCosine {
    pub(super) inner: CompensatedIP,
}
impl CompensatedCosine {
    /// Wrap an already-constructed [`CompensatedIP`] functor.
    pub fn new(inner: CompensatedIP) -> Self {
        Self { inner }
    }
}
/// Mathematical result: identical to the inner-product similarity.
impl<A, T, U> Target2<A, distances::MathematicalResult<f32>, T, U> for CompensatedCosine
where
    A: Architecture,
    CompensatedIP: Target2<A, distances::MathematicalResult<f32>, T, U>,
{
    #[inline(always)]
    fn run(self, arch: A, x: T, y: U) -> distances::MathematicalResult<f32> {
        self.inner.run(arch, x, y)
    }
}
/// Final distance: `1 - similarity`, the cosine-distance convention.
impl<A, T, U> Target2<A, distances::Result<f32>, T, U> for CompensatedCosine
where
    A: Architecture,
    Self: Target2<A, distances::MathematicalResult<f32>, T, U>,
{
    #[inline(always)]
    fn run(self, arch: A, x: T, y: U) -> distances::Result<f32> {
        let r: MV<f32> = self.run(arch, x, y)?;
        Ok(1.0 - r.into_inner())
    }
}
#[cfg(test)]
mod tests {
use diskann_utils::{Reborrow, lazy_format};
use diskann_vector::{PureDistanceFunction, distance::Metric, norm::FastL2Norm};
use diskann_wide::ARCH;
use rand::{
SeedableRng,
distr::{Distribution, Uniform},
rngs::StdRng,
};
use rand_distr::StandardNormal;
use super::*;
use crate::{
alloc::GlobalAllocator,
bits::{BitTranspose, Dense},
};
#[derive(Debug, Clone, Copy, PartialEq)]
struct Approx {
absolute: f32,
relative: f32,
}
impl Approx {
const fn new(absolute: f32, relative: f32) -> Self {
assert!(absolute >= 0.0);
assert!(relative >= 0.0);
Self { absolute, relative }
}
fn check(&self, got: f32, expected: f32, ctx: Option<&dyn std::fmt::Display>) -> bool {
struct Ctx<'a>(Option<&'a dyn std::fmt::Display>);
impl std::fmt::Display for Ctx<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0 {
None => write!(f, "none"),
Some(d) => write!(f, "{}", d),
}
}
}
let absolute = (got - expected).abs();
if absolute <= self.absolute {
true
} else {
let relative = absolute / expected.abs();
if relative <= self.relative {
true
} else {
panic!(
"got {}, expected {}. Abs/Rel = {}/{} with bounds {}/{}: Ctx: {}",
got,
expected,
absolute,
relative,
self.absolute,
self.relative,
Ctx(ctx)
);
}
}
}
}
#[test]
fn test_data_meta() {
let meta = DataMeta::new(1.0, 2.0, 10).unwrap();
let expected = DataMetaF32 {
inner_product_correction: 1.0,
metric_specific: 2.0,
bit_sum: 10.0,
};
assert_eq!(meta.to_full(ARCH), expected);
let err = DataMeta::new(65600.0, 2.0, 10).unwrap_err();
assert_eq!(
err.to_string(),
"inner product correction 65600 cannot fit in a 16-bit floating point number"
);
let err = DataMeta::new(2.0, 65600.0, 10).unwrap_err();
assert_eq!(
err.to_string(),
"metric specific correction 65600 cannot fit in a 16-bit floating point number"
);
let err = DataMeta::new(2.0, 2.0, 65536).unwrap_err();
assert_eq!(
err.to_string(),
"bit sum 65536 cannot fit in a 16-bit unsigned integer",
);
}
#[test]
fn supported_metric() {
assert_eq!(
SupportedMetric::try_from(Metric::L2).unwrap(),
SupportedMetric::SquaredL2
);
assert_eq!(
SupportedMetric::try_from(Metric::InnerProduct).unwrap(),
SupportedMetric::InnerProduct
);
assert_eq!(
SupportedMetric::try_from(Metric::Cosine).unwrap(),
SupportedMetric::Cosine
);
assert!(matches!(
SupportedMetric::try_from(Metric::CosineNormalized),
Err(UnsupportedMetric(Metric::CosineNormalized))
));
assert_eq!(SupportedMetric::SquaredL2, Metric::L2);
assert_ne!(SupportedMetric::SquaredL2, Metric::InnerProduct);
assert_ne!(SupportedMetric::SquaredL2, Metric::Cosine);
assert_ne!(SupportedMetric::SquaredL2, Metric::CosineNormalized);
assert_ne!(SupportedMetric::InnerProduct, Metric::L2);
assert_eq!(SupportedMetric::InnerProduct, Metric::InnerProduct);
assert_ne!(SupportedMetric::SquaredL2, Metric::Cosine);
assert_ne!(SupportedMetric::SquaredL2, Metric::CosineNormalized);
}
    /// A randomly generated compressed vector together with its exact
    /// reconstruction and the scalars its metadata was derived from.
    struct Reference<T> {
        compressed: T,
        reconstructed: Vec<f32>,
        // Shifted norm assigned to the vector.
        norm: f32,
        // Inner product with the centroid assigned to the vector.
        center_ip: f32,
        // Self inner product; only populated for `Data` references.
        self_ip: Option<f32>,
    }
    /// Generate a random compressed vector plus its reference values.
    trait GenerateReference: Sized {
        fn generate_reference(
            center: &[f32],
            metric: SupportedMetric,
            rng: &mut StdRng,
        ) -> Reference<Self>;
    }
    impl<const NBITS: usize> GenerateReference for Data<NBITS, GlobalAllocator>
    where
        Unsigned: Representation<NBITS>,
    {
        /// Draw uniform codes, build the recentred unit-norm
        /// reconstruction, and attach metadata derived from random `norm`,
        /// `center_ip`, and `self_ip` scalars.
        fn generate_reference(
            center: &[f32],
            metric: SupportedMetric,
            rng: &mut StdRng,
        ) -> Reference<Self> {
            let dim = center.len();
            let mut reconstructed = vec![0.0f32; dim];
            let mut compressed = Data::<NBITS, _>::new_boxed(dim);
            let mut bit_sum = 0;
            let dist = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
            // Code-domain midpoint, mirroring `DataMeta::offset_term`.
            let offset = (2usize.pow(NBITS as u32) as f32 - 1.0) / 2.0;
            for (i, r) in reconstructed.iter_mut().enumerate() {
                let b: i64 = dist.sample(rng);
                bit_sum += b;
                compressed.vector_mut().set(i, b).unwrap();
                *r = (b as f32) - offset;
            }
            // Normalize the reconstruction to unit length.
            let r_norm = FastL2Norm.evaluate(reconstructed.as_slice());
            reconstructed.iter_mut().for_each(|i| *i /= r_norm);
            let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
            let center_ip: f32 = Uniform::new(0.5, 2.5).unwrap().sample(rng);
            let self_ip: f32 = Uniform::new(0.5, 1.5).unwrap().sample(rng);
            compressed.set_meta(
                DataMeta::new(
                    norm / (self_ip * r_norm),
                    metric.pick(norm, center_ip),
                    bit_sum.try_into().unwrap(),
                )
                .unwrap(),
            );
            Reference {
                compressed,
                reconstructed,
                norm,
                center_ip,
                self_ip: Some(self_ip),
            }
        }
    }
    impl<const NBITS: usize, Perm> GenerateReference for Query<NBITS, Perm, GlobalAllocator>
    where
        Unsigned: Representation<NBITS>,
        Perm: PermutationStrategy<NBITS>,
    {
        /// Build an affinely quantized query (`value = base + scale * code`)
        /// and metadata matching the asymmetric kernels' expectations.
        fn generate_reference(
            center: &[f32],
            metric: SupportedMetric,
            rng: &mut StdRng,
        ) -> Reference<Self> {
            let dim = center.len();
            let mut reconstructed = vec![0.0f32; dim];
            let mut compressed = Query::<NBITS, Perm, _>::new_boxed(dim);
            let distribution = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
            let base: f32 = StandardNormal {}.sample(rng);
            // `scale` must be positive for the offset math to make sense.
            let scale: f32 = {
                let scale: f32 = StandardNormal {}.sample(rng);
                scale.abs()
            };
            let mut bit_sum = 0;
            for (i, r) in reconstructed.iter_mut().enumerate() {
                let b = distribution.sample(rng);
                compressed.vector_mut().set(i, b).unwrap();
                *r = base + scale * (b as f32);
                bit_sum += b;
            }
            let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
            let center_ip: f32 = Uniform::new(-2.0, 2.0).unwrap().sample(rng);
            compressed.set_meta(QueryMeta {
                inner_product_correction: norm * scale,
                bit_sum: bit_sum as f32,
                // The kernels expect `offset` as base / scale.
                offset: base / scale,
                metric_specific: metric.pick(norm, center_ip),
            });
            Reference {
                compressed,
                reconstructed,
                norm,
                center_ip,
                self_ip: None,
            }
        }
    }
    impl GenerateReference for FullQuery<GlobalAllocator> {
        /// Build a unit-norm full-precision query with metadata matching
        /// the full-query kernels' expectations.
        fn generate_reference(
            center: &[f32],
            metric: SupportedMetric,
            rng: &mut StdRng,
        ) -> Reference<Self> {
            let dim = center.len();
            let mut query = FullQuery::empty(dim, GlobalAllocator).unwrap();
            let mut sum = 0.0;
            let dist = StandardNormal {};
            for r in query.data.iter_mut() {
                let b: f32 = dist.sample(rng);
                sum += b;
                *r = b;
            }
            // Normalize and rescale the pre-normalization component sum to
            // stay consistent with the stored data.
            let r_norm = FastL2Norm.evaluate(&*query.data);
            query.data.iter_mut().for_each(|i| *i /= r_norm);
            let norm: f32 = Uniform::new(0.0, 2.0).unwrap().sample(rng);
            let center_ip: f32 = Uniform::new(-2.0, 2.0).unwrap().sample(rng);
            query.meta = FullQueryMeta {
                sum: sum / r_norm,
                shifted_norm: norm,
                metric_specific: metric.pick(norm, center_ip),
            };
            let reconstructed = query.data.to_vec();
            Reference {
                compressed: query,
                reconstructed,
                norm,
                center_ip,
                self_ip: None,
            }
        }
    }
    /// Symmetric (data-vs-data) sweep: checks `CompensatedIP`,
    /// `CompensatedCosine`, and `CompensatedSquaredL2` against references
    /// computed from the exact reconstructions, and verifies the
    /// `Result<f32>` adapters apply the right final transforms.
    fn test_compensated_distance<const NBITS: usize>(
        dim: usize,
        ntrials: usize,
        err_l2: Approx,
        err_ip: Approx,
        rng: &mut StdRng,
    ) where
        Unsigned: Representation<NBITS>,
        for<'a> CompensatedIP: Target2<
                diskann_wide::arch::Current,
                distances::Result<f32>,
                DataRef<'a, NBITS>,
                DataRef<'a, NBITS>,
            > + Target2<
                diskann_wide::arch::Current,
                distances::MathematicalResult<f32>,
                DataRef<'a, NBITS>,
                DataRef<'a, NBITS>,
            >,
        for<'a> CompensatedSquaredL2: Target2<
                diskann_wide::arch::Current,
                distances::Result<f32>,
                DataRef<'a, NBITS>,
                DataRef<'a, NBITS>,
            > + Target2<
                diskann_wide::arch::Current,
                distances::MathematicalResult<f32>,
                DataRef<'a, NBITS>,
                DataRef<'a, NBITS>,
            >,
    {
        let mut center = vec![0.0f32; dim];
        for trial in 0..ntrials {
            center
                .iter_mut()
                .for_each(|c| *c = StandardNormal {}.sample(rng));
            let c_square_norm = FastL2NormSquared.evaluate(&*center);
            // Inner product / cosine.
            {
                let x = Data::<NBITS, _>::generate_reference(
                    &center,
                    SupportedMetric::InnerProduct,
                    rng,
                );
                let y = Data::<NBITS, _>::generate_reference(
                    &center,
                    SupportedMetric::InnerProduct,
                    rng,
                );
                // Exact cross term from the reconstructions, undoing the
                // per-vector normalizations applied during generation.
                let kernel_result = {
                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                        &*x.reconstructed,
                        &*y.reconstructed,
                    );
                    x.norm * y.norm * xy.into_inner() / (x.self_ip.unwrap() * y.self_ip.unwrap())
                };
                let reference_ip = kernel_result + x.center_ip + y.center_ip + c_square_norm;
                let ip = CompensatedIP::new(&center, center.len());
                let got_ip: distances::MathematicalResult<f32> =
                    ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
                let got_ip = got_ip.unwrap();
                let ctx = &lazy_format!(
                    "Inner Product, trial {} of {}, dim = {}",
                    trial,
                    ntrials,
                    dim
                );
                assert!(err_ip.check(got_ip.into_inner(), reference_ip, Some(ctx)));
                // The distance adapter must negate the similarity.
                let got_ip_f32: distances::Result<f32> =
                    ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
                let got_ip_f32 = got_ip_f32.unwrap();
                assert_eq!(got_ip_f32, -got_ip.into_inner());
                let cosine = CompensatedCosine::new(ip);
                let got_cosine: distances::MathematicalResult<f32> =
                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
                let got_cosine = got_cosine.unwrap();
                assert_eq!(
                    got_cosine.into_inner(),
                    got_ip.into_inner(),
                    "cosine and IP should be the same"
                );
                // The cosine adapter must apply `1 - similarity`.
                let got_cosine_f32: distances::Result<f32> =
                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
                let got_cosine_f32 = got_cosine_f32.unwrap();
                assert_eq!(
                    got_cosine_f32,
                    1.0 - got_cosine.into_inner(),
                    "incorrect transform performed"
                );
            }
            // Squared L2.
            {
                let x =
                    Data::<NBITS, _>::generate_reference(&center, SupportedMetric::SquaredL2, rng);
                let y =
                    Data::<NBITS, _>::generate_reference(&center, SupportedMetric::SquaredL2, rng);
                let kernel_result = {
                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                        &*x.reconstructed,
                        &*y.reconstructed,
                    );
                    x.norm * y.norm * xy.into_inner() / (x.self_ip.unwrap() * y.self_ip.unwrap())
                };
                let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * kernel_result;
                let l2 = CompensatedSquaredL2::new(dim);
                let got_l2: distances::MathematicalResult<f32> =
                    ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
                let got_l2 = got_l2.unwrap();
                let ctx =
                    &lazy_format!("Squared L2, trial {} of {}, dim = {}", trial, ntrials, dim);
                assert!(err_l2.check(got_l2.into_inner(), reference_l2, Some(ctx)));
                // The L2 adapter is the identity on the mathematical value.
                let got_l2_f32: distances::Result<f32> =
                    ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
                let got_l2_f32 = got_l2_f32.unwrap();
                assert_eq!(got_l2_f32, got_l2.into_inner());
            }
        }
    }
    /// Asymmetric (quantized query vs quantized data) sweep, exercising the
    /// `QueryRef`/`DataRef` impls of all three compensated distances.
    fn test_mixed_compensated_distance<const Q: usize, const D: usize, Perm>(
        dim: usize,
        ntrials: usize,
        err_l2: Approx,
        err_ip: Approx,
        rng: &mut StdRng,
    ) where
        Unsigned: Representation<Q>,
        Unsigned: Representation<D>,
        Perm: PermutationStrategy<Q>,
        for<'a> CompensatedIP: Target2<
            diskann_wide::arch::Current,
            distances::MathematicalResult<f32>,
            QueryRef<'a, Q, Perm>,
            DataRef<'a, D>,
        >,
        for<'a> CompensatedSquaredL2: Target2<
            diskann_wide::arch::Current,
            distances::MathematicalResult<f32>,
            QueryRef<'a, Q, Perm>,
            DataRef<'a, D>,
        >,
        for<'a> CompensatedCosine: Target2<
            diskann_wide::arch::Current,
            distances::MathematicalResult<f32>,
            QueryRef<'a, Q, Perm>,
            DataRef<'a, D>,
        >,
        for<'a> CompensatedIP: Target2<
            diskann_wide::arch::Current,
            distances::Result<f32>,
            QueryRef<'a, Q, Perm>,
            DataRef<'a, D>,
        >,
        for<'a> CompensatedSquaredL2: Target2<
            diskann_wide::arch::Current,
            distances::Result<f32>,
            QueryRef<'a, Q, Perm>,
            DataRef<'a, D>,
        >,
        for<'a> CompensatedCosine: Target2<
            diskann_wide::arch::Current,
            distances::Result<f32>,
            QueryRef<'a, Q, Perm>,
            DataRef<'a, D>,
        >,
    {
        let mut center = vec![0.0f32; dim];
        for trial in 0..ntrials {
            center
                .iter_mut()
                .for_each(|c| *c = StandardNormal {}.sample(rng));
            let c_square_norm = FastL2NormSquared.evaluate(&*center);
            // Inner product / cosine.
            {
                let x = Query::<Q, Perm, _>::generate_reference(
                    &center,
                    SupportedMetric::InnerProduct,
                    rng,
                );
                let y =
                    Data::<D, _>::generate_reference(&center, SupportedMetric::InnerProduct, rng);
                // Only the data side was normalized during generation.
                let xy = {
                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                        &*x.reconstructed,
                        &*y.reconstructed,
                    );
                    x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
                };
                // The `Result<f32>` adapter negates, so negate the reference.
                let reference_ip = -(xy + x.center_ip + y.center_ip + c_square_norm);
                let ip = CompensatedIP::new(&center, center.len());
                let got_ip: distances::Result<f32> =
                    ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
                let got_ip = got_ip.unwrap();
                let ctx = &lazy_format!(
                    "Inner Product, trial = {} of {}, dim = {}",
                    trial,
                    ntrials,
                    dim
                );
                assert!(err_ip.check(got_ip, reference_ip, Some(ctx)));
                let cosine = CompensatedCosine::new(ip);
                let got_cosine: distances::MathematicalResult<f32> =
                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
                let got_cosine = got_cosine.unwrap();
                assert_eq!(
                    got_cosine.into_inner(),
                    -got_ip,
                    "cosine and IP should be the same"
                );
                let got_cosine_f32: distances::Result<f32> =
                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
                let got_cosine_f32 = got_cosine_f32.unwrap();
                assert_eq!(
                    got_cosine_f32,
                    1.0 - got_cosine.into_inner(),
                    "incorrect transform performed"
                );
            }
            // Squared L2.
            {
                let x = Query::<Q, Perm, _>::generate_reference(
                    &center,
                    SupportedMetric::SquaredL2,
                    rng,
                );
                let y = Data::<D, _>::generate_reference(&center, SupportedMetric::SquaredL2, rng);
                let xy = {
                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                        &*x.reconstructed,
                        &*y.reconstructed,
                    );
                    x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
                };
                let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * xy;
                let l2 = CompensatedSquaredL2::new(dim);
                let got_l2: distances::Result<f32> =
                    ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
                let got_l2 = got_l2.unwrap();
                let ctx = &lazy_format!(
                    "Squared L2, trial = {} of {}, dim = {}",
                    trial,
                    ntrials,
                    dim
                );
                assert!(err_l2.check(got_l2, reference_l2, Some(ctx)));
            }
        }
    }
    /// Full-precision query vs quantized data sweep, exercising the
    /// `FullQueryRef`/`DataRef` impls of all three compensated distances.
    fn test_full_distances<const NBITS: usize>(
        dim: usize,
        ntrials: usize,
        err_l2: Approx,
        err_ip: Approx,
        rng: &mut StdRng,
    ) where
        Unsigned: Representation<NBITS>,
        for<'a> CompensatedIP: Target2<
            diskann_wide::arch::Current,
            distances::MathematicalResult<f32>,
            FullQueryRef<'a>,
            DataRef<'a, NBITS>,
        >,
        for<'a> CompensatedSquaredL2: Target2<
            diskann_wide::arch::Current,
            distances::MathematicalResult<f32>,
            FullQueryRef<'a>,
            DataRef<'a, NBITS>,
        >,
        for<'a> CompensatedCosine: Target2<
            diskann_wide::arch::Current,
            distances::MathematicalResult<f32>,
            FullQueryRef<'a>,
            DataRef<'a, NBITS>,
        >,
        for<'a> CompensatedIP: Target2<
            diskann_wide::arch::Current,
            distances::Result<f32>,
            FullQueryRef<'a>,
            DataRef<'a, NBITS>,
        >,
        for<'a> CompensatedSquaredL2: Target2<
            diskann_wide::arch::Current,
            distances::Result<f32>,
            FullQueryRef<'a>,
            DataRef<'a, NBITS>,
        >,
        for<'a> CompensatedCosine: Target2<
            diskann_wide::arch::Current,
            distances::Result<f32>,
            FullQueryRef<'a>,
            DataRef<'a, NBITS>,
        >,
    {
        let mut center = vec![0.0f32; dim];
        for trial in 0..ntrials {
            center
                .iter_mut()
                .for_each(|c| *c = StandardNormal {}.sample(rng));
            let c_square_norm = FastL2NormSquared.evaluate(&*center);
            // Inner product / cosine.
            {
                let x = FullQuery::generate_reference(&center, SupportedMetric::InnerProduct, rng);
                let y = Data::<NBITS, _>::generate_reference(
                    &center,
                    SupportedMetric::InnerProduct,
                    rng,
                );
                // Only the data side carries a `self_ip` normalization.
                let xy = {
                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                        &*x.reconstructed,
                        &*y.reconstructed,
                    );
                    x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
                };
                // The `Result<f32>` adapter negates, so negate the reference.
                let reference_ip = -(xy + x.center_ip + y.center_ip + c_square_norm);
                let ip = CompensatedIP::new(&center, center.len());
                let got_ip: distances::Result<f32> =
                    ARCH.run2(ip, x.compressed.reborrow(), y.compressed.reborrow());
                let got_ip = got_ip.unwrap();
                let ctx = &lazy_format!(
                    "Inner Product, trial = {} of {}, dim = {}",
                    trial,
                    ntrials,
                    dim
                );
                assert!(err_ip.check(got_ip, reference_ip, Some(ctx)));
                let cosine = CompensatedCosine::new(ip);
                let got_cosine: distances::MathematicalResult<f32> =
                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
                let got_cosine = got_cosine.unwrap();
                assert_eq!(
                    got_cosine.into_inner(),
                    -got_ip,
                    "cosine and IP should be the same"
                );
                let got_cosine_f32: distances::Result<f32> =
                    ARCH.run2(cosine, x.compressed.reborrow(), y.compressed.reborrow());
                let got_cosine_f32 = got_cosine_f32.unwrap();
                assert_eq!(
                    got_cosine_f32,
                    1.0 - got_cosine.into_inner(),
                    "incorrect transform performed"
                );
            }
            // Squared L2.
            {
                let x = FullQuery::generate_reference(&center, SupportedMetric::SquaredL2, rng);
                let y =
                    Data::<NBITS, _>::generate_reference(&center, SupportedMetric::SquaredL2, rng);
                let xy = {
                    let xy: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                        &*x.reconstructed,
                        &*y.reconstructed,
                    );
                    x.norm * y.norm * xy.into_inner() / y.self_ip.unwrap()
                };
                let reference_l2 = x.norm * x.norm + y.norm * y.norm - 2.0 * xy;
                let l2 = CompensatedSquaredL2::new(dim);
                let got_l2: distances::Result<f32> =
                    ARCH.run2(l2, x.compressed.reborrow(), y.compressed.reborrow());
                let got_l2 = got_l2.unwrap();
                let ctx = &lazy_format!(
                    "Squared L2, trial = {} of {}, dim = {}",
                    trial,
                    ntrials,
                    dim
                );
                assert!(err_l2.check(got_l2, reference_l2, Some(ctx)));
            }
        }
    }
    // Under Miri the full sweep would be far too slow to interpret, so use
    // much smaller dimensions and a single trial per dimension.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 37;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }
#[test]
fn test_symmetric_distances_1bit() {
let mut rng = StdRng::seed_from_u64(0x2a5f79a2469218f6);
for dim in 1..MAX_DIM {
test_compensated_distance::<1>(
dim,
TRIALS_PER_DIM,
Approx::new(4.0e-3, 3.0e-3),
Approx::new(1.0e-3, 5.0e-4),
&mut rng,
);
}
}
#[test]
fn test_symmetric_distances_2bit() {
let mut rng = StdRng::seed_from_u64(0x68f8f52057f94399);
for dim in 1..MAX_DIM {
test_compensated_distance::<2>(
dim,
TRIALS_PER_DIM,
Approx::new(3.5e-3, 2.0e-3),
Approx::new(2.0e-3, 5.0e-4),
&mut rng,
);
}
}
#[test]
fn test_symmetric_distances_4bit() {
let mut rng = StdRng::seed_from_u64(0xb88d76ac4c58e923);
for dim in 1..MAX_DIM {
test_compensated_distance::<4>(
dim,
TRIALS_PER_DIM,
Approx::new(2.0e-3, 2.0e-3),
Approx::new(2.0e-3, 5.0e-4),
&mut rng,
);
}
}
#[test]
fn test_symmetric_distances_8bit() {
let mut rng = StdRng::seed_from_u64(0x1c2b79873ee32626);
for dim in 1..MAX_DIM {
test_compensated_distance::<8>(
dim,
TRIALS_PER_DIM,
Approx::new(2.0e-3, 2.0e-3),
Approx::new(2.0e-3, 4.0e-4),
&mut rng,
);
}
}
#[test]
fn test_mixed_distances_4x1() {
let mut rng = StdRng::seed_from_u64(0x1efb4d87ed0a8ada);
for dim in 1..MAX_DIM {
test_mixed_compensated_distance::<4, 1, BitTranspose>(
dim,
TRIALS_PER_DIM,
Approx::new(4.0e-3, 3.0e-3),
Approx::new(1.3e-2, 8.3e-3),
&mut rng,
);
}
}
#[test]
fn test_mixed_distances_4x4() {
let mut rng = StdRng::seed_from_u64(0x508554264eb7a51b);
for dim in 1..MAX_DIM {
test_mixed_compensated_distance::<4, 4, Dense>(
dim,
TRIALS_PER_DIM,
Approx::new(4.0e-3, 3.0e-3),
Approx::new(3.0e-4, 8.3e-2),
&mut rng,
);
}
}
#[test]
fn test_mixed_distances_8x8() {
let mut rng = StdRng::seed_from_u64(0x8acd8e4224c76c43);
for dim in 1..MAX_DIM {
test_mixed_compensated_distance::<8, 8, Dense>(
dim,
TRIALS_PER_DIM,
Approx::new(2.0e-3, 6.0e-3),
Approx::new(1.0e-2, 3.0e-2),
&mut rng,
);
}
}
#[test]
fn test_full_distances_1bit() {
let mut rng = StdRng::seed_from_u64(0x7f93530559f42d66);
for dim in 1..MAX_DIM {
test_full_distances::<1>(
dim,
TRIALS_PER_DIM,
Approx::new(1.0e-3, 2.0e-3),
Approx::new(0.0, 5.0e-3),
&mut rng,
);
}
}
#[test]
fn test_full_distances_2bit() {
let mut rng = StdRng::seed_from_u64(0xa3ad61d3d03a0c5a);
for dim in 1..MAX_DIM {
test_full_distances::<2>(
dim,
TRIALS_PER_DIM,
Approx::new(2.0e-3, 1.1e-3),
Approx::new(7.0e-4, 1.0e-3),
&mut rng,
);
}
}
#[test]
fn test_full_distances_4bit() {
let mut rng = StdRng::seed_from_u64(0x3e2f50ed7c64f0c2);
for dim in 1..MAX_DIM {
test_full_distances::<4>(
dim,
TRIALS_PER_DIM,
Approx::new(2.0e-3, 1.0e-2),
Approx::new(1.0e-3, 5.0e-4),
&mut rng,
);
}
}
#[test]
fn test_full_distances_8bit() {
let mut rng = StdRng::seed_from_u64(0x95705070e415c6d3);
for dim in 1..MAX_DIM {
test_full_distances::<8>(
dim,
TRIALS_PER_DIM,
Approx::new(1.0e-3, 1.0e-3),
Approx::new(2.0e-3, 1.0e-4),
&mut rng,
);
}
}
}