use diskann_vector::{DistanceFunction, PureDistanceFunction};
use super::inverse_bit_scale;
use crate::{
bits::{BitSlice, Dense, Representation, Unsigned},
distances::{self, InnerProduct, MV, SquaredL2, check_lengths},
meta,
};
/// Per-vector scalar correction carried as metadata next to the quantized codes.
///
/// [`CompensatedIP`] adds the compensation of both operands to the scaled inner product
/// of the codes. The squared-L2 based functors in this file never read it.
///
/// The wrapper is `#[repr(transparent)]` over a single `f32` and derives the `bytemuck`
/// `Zeroable`/`Pod` marker traits.
#[derive(Default, Debug, Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(transparent)]
pub struct Compensation(pub f32);
/// Alias of `meta::VectorRef` for `NBITS`-bit [`Unsigned`] codes whose metadata is a
/// [`Compensation`]. `Perm` defaults to [`Dense`].
///
/// This is the argument type accepted by the distance functors in this file.
pub type CompensatedVectorRef<'a, const NBITS: usize, Perm = Dense> =
meta::VectorRef<'a, NBITS, Unsigned, Compensation, Perm>;
/// Alias of `meta::VectorMut` for `NBITS`-bit [`Unsigned`] codes whose metadata is a
/// [`Compensation`]. `Perm` defaults to [`Dense`].
pub type MutCompensatedVectorRef<'a, const NBITS: usize, Perm = Dense> =
meta::VectorMut<'a, NBITS, Unsigned, Compensation, Perm>;
/// Alias of `meta::Vector` for `NBITS`-bit [`Unsigned`] codes whose metadata is a
/// [`Compensation`]. `Perm` defaults to [`Dense`].
pub type CompensatedVector<const NBITS: usize, Perm = Dense> =
meta::Vector<NBITS, Unsigned, Compensation, Perm>;
/// Distance functor computing a squared Euclidean distance between two
/// compensated vectors from their integer codes.
///
/// The integer squared-L2 of the codes is multiplied by
/// `inverse_bit_scale::<NBITS>()^2 * scale_squared`.
#[derive(Debug, Clone, Copy)]
pub struct CompensatedSquaredL2 {
// Multiplier applied (together with the squared inverse bit scale) to the integer
// squared-L2 of the codes.
pub(super) scale_squared: f32,
}
impl CompensatedSquaredL2 {
pub fn new(scale_squared: f32) -> Self {
Self { scale_squared }
}
}
impl<const NBITS: usize>
DistanceFunction<
CompensatedVectorRef<'_, NBITS>,
CompensatedVectorRef<'_, NBITS>,
distances::MathematicalResult<f32>,
> for CompensatedSquaredL2
where
Unsigned: Representation<NBITS>,
SquaredL2: for<'a, 'b> PureDistanceFunction<
BitSlice<'a, NBITS, Unsigned>,
BitSlice<'b, NBITS, Unsigned>,
distances::MathematicalResult<u32>,
>,
{
fn evaluate_similarity(
&self,
x: CompensatedVectorRef<'_, NBITS>,
y: CompensatedVectorRef<'_, NBITS>,
) -> distances::MathematicalResult<f32> {
check_lengths!(x, y)?;
let squared_l2: distances::MathematicalResult<u32> =
SquaredL2::evaluate(x.vector(), y.vector());
let squared_l2 = squared_l2?.into_inner() as f32;
let bit_scale = inverse_bit_scale::<NBITS>() * inverse_bit_scale::<NBITS>();
let result = bit_scale * self.scale_squared * squared_l2;
Ok(MV::new(result))
}
}
/// Squared Euclidean distance between two compensated vectors, returned as a plain `f32`.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::Result<f32>,
    > for CompensatedSquaredL2
where
    Unsigned: Representation<NBITS>,
    Self: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::MathematicalResult<f32>,
        >,
{
    /// Delegates to the `MathematicalResult` implementation and unwraps the value
    /// unchanged.
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::Result<f32> {
        // The annotation selects the `MathematicalResult` flavour of the trait.
        let mathematical: distances::MathematicalResult<f32> = self.evaluate_similarity(x, y);
        Ok(mathematical?.into_inner())
    }
}
/// Distance functor computing an inner product between two compensated vectors
/// from their integer codes and per-vector [`Compensation`] terms.
///
/// The result is
/// `inverse_bit_scale::<NBITS>()^2 * scale_squared * ip(codes) + shift_square_norm
/// + (compensation(y) + compensation(x))`.
#[derive(Debug, Clone, Copy)]
pub struct CompensatedIP {
// Multiplier applied (together with the squared inverse bit scale) to the integer
// inner product of the codes.
pub(super) scale_squared: f32,
// Constant added to every inner product. The tests in this file pass the sum of
// squares of the shift vector.
pub(super) shift_square_norm: f32,
}
impl CompensatedIP {
pub fn new(scale_squared: f32, shift_square_norm: f32) -> Self {
Self {
scale_squared,
shift_square_norm,
}
}
}
/// Inner product between two compensated vectors, returned as a mathematical value.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::MathematicalResult<f32>,
    > for CompensatedIP
where
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
{
    /// Returns the scaled integer inner product of the codes plus the constant shift
    /// term and the compensation of both operands.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the integer inner-product kernel.
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        // Integer inner product computed directly on the packed codes.
        let code_product: distances::MathematicalResult<u32> =
            InnerProduct::evaluate(x.vector(), y.vector());
        let code_product = code_product?.into_inner() as f32;

        let inverse = inverse_bit_scale::<NBITS>();
        let factor = (inverse * inverse) * self.scale_squared;

        // Per-vector correction terms stored in the metadata slot.
        let compensation = y.meta().0 + x.meta().0;

        // `factor * code_product + shift_square_norm` is fused into one `mul_add`.
        let similarity = factor.mul_add(code_product, self.shift_square_norm) + compensation;
        Ok(MV::new(similarity))
    }
}
/// Inner-product distance between two compensated vectors, returned as a plain `f32`.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::Result<f32>,
    > for CompensatedIP
where
    Unsigned: Representation<NBITS>,
    Self: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::MathematicalResult<f32>,
        >,
{
    /// Delegates to the `MathematicalResult` implementation and returns the negated
    /// inner product, so a larger inner product yields a smaller value.
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::Result<f32> {
        // The annotation selects the `MathematicalResult` flavour of the trait.
        let mathematical: distances::MathematicalResult<f32> = self.evaluate_similarity(x, y);
        let inner_product = mathematical?.into_inner();
        Ok(-inner_product)
    }
}
/// Distance functor computing `1 - l2 / 2`, where `l2` is the scaled squared
/// Euclidean distance between the codes of two compensated vectors.
///
/// NOTE(review): the name suggests the inputs are expected to be normalized, for which
/// `1 - l2 / 2` equals the cosine similarity; nothing in this file enforces that.
#[derive(Debug, Clone, Copy)]
pub struct CompensatedCosineNormalized {
// Multiplier applied (together with the squared inverse bit scale) to the integer
// squared-L2 of the codes.
pub(super) scale_squared: f32,
}
impl CompensatedCosineNormalized {
pub fn new(scale_squared: f32) -> Self {
Self { scale_squared }
}
}
/// Similarity `1 - l2 / 2` between two compensated vectors, returned as a
/// mathematical value.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::MathematicalResult<f32>,
    > for CompensatedCosineNormalized
where
    Unsigned: Representation<NBITS>,
    SquaredL2: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
{
    /// Returns `1 - l2 / 2`, where
    /// `l2 = inverse_bit_scale^2 * scale_squared * squared_l2(codes(x), codes(y))`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the integer squared-L2 kernel.
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        // Integer squared-L2 computed directly on the packed codes.
        let code_distance: distances::MathematicalResult<u32> =
            SquaredL2::evaluate(x.vector(), y.vector());
        let code_distance = code_distance?.into_inner() as f32;

        let inverse = inverse_bit_scale::<NBITS>();
        let factor = (inverse * inverse) * self.scale_squared;
        let scaled_l2 = factor * code_distance;

        Ok(MV::new(1.0 - scaled_l2 / 2.0))
    }
}
/// Distance `1 - similarity` between two compensated vectors, returned as a plain `f32`.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::Result<f32>,
    > for CompensatedCosineNormalized
where
    Unsigned: Representation<NBITS>,
    Self: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::MathematicalResult<f32>,
        >,
{
    /// Delegates to the `MathematicalResult` implementation and returns one minus the
    /// similarity it reports.
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::Result<f32> {
        // The annotation selects the `MathematicalResult` flavour of the trait.
        let mathematical: distances::MathematicalResult<f32> = self.evaluate_similarity(x, y);
        let similarity = mathematical?.into_inner();
        Ok(1.0 - similarity)
    }
}
#[cfg(test)]
mod tests {
use diskann_utils::{Reborrow, ReborrowMut};
use rand::{
Rng, SeedableRng,
distr::{Distribution, Uniform},
rngs::StdRng,
};
use super::*;
use crate::{
bits::{Representation, Unsigned},
scalar::bit_scale,
test_util,
};
/// Checks the three compensated distance functors against full-precision references
/// for `ntrials` random inputs of dimension `dim`.
///
/// Each trial draws a scale `alpha`, a shift vector `beta` and two random code vectors,
/// reconstructs the corresponding `f32` vectors as `alpha * code * inverse_bit_scale + beta`,
/// and compares:
///
/// * `CompensatedSquaredL2` with `diskann_vector::distance::SquaredL2`,
/// * `CompensatedIP` with `diskann_vector::distance::InnerProduct`,
/// * `CompensatedCosineNormalized` with `1 - squared_l2 / 2` of the reconstructions.
///
/// A comparison passes when either the relative error is within the per-metric
/// tolerance (`max_relative_err_*`) or the absolute error is within
/// `max_absolute_error`. The `distances::Result<f32>` flavour of every functor is
/// additionally required to match the `MathematicalResult` flavour exactly (identical
/// for L2, negated for IP, `1 - value` for cosine).
///
/// Panics (via `assert!`/`unwrap`) on any mismatch or evaluation error.
fn test_compensated_distance<const NBITS: usize, R>(
dim: usize,
ntrials: usize,
max_relative_err_l2: f32,
max_relative_err_ip: f32,
max_relative_err_cos: f32,
max_absolute_error: f32,
rng: &mut R,
) where
Unsigned: Representation<NBITS>,
R: Rng,
CompensatedSquaredL2: for<'a, 'b> DistanceFunction<
CompensatedVectorRef<'a, NBITS>,
CompensatedVectorRef<'b, NBITS>,
distances::MathematicalResult<f32>,
>,
CompensatedSquaredL2: for<'a, 'b> DistanceFunction<
CompensatedVectorRef<'a, NBITS>,
CompensatedVectorRef<'b, NBITS>,
distances::Result<f32>,
>,
CompensatedIP: for<'a, 'b> DistanceFunction<
CompensatedVectorRef<'a, NBITS>,
CompensatedVectorRef<'b, NBITS>,
distances::MathematicalResult<f32>,
>,
CompensatedIP: for<'a, 'b> DistanceFunction<
CompensatedVectorRef<'a, NBITS>,
CompensatedVectorRef<'b, NBITS>,
distances::Result<f32>,
>,
CompensatedCosineNormalized: for<'a, 'b> DistanceFunction<
CompensatedVectorRef<'a, NBITS>,
CompensatedVectorRef<'b, NBITS>,
distances::MathematicalResult<f32>,
>,
CompensatedCosineNormalized: for<'a, 'b> DistanceFunction<
CompensatedVectorRef<'a, NBITS>,
CompensatedVectorRef<'b, NBITS>,
distances::Result<f32>,
>,
{
// `alpha` and `beta` are drawn as small integers and divided by a power of two, so
// every sampled value is exactly representable: alpha in [-0.25, 0.25] in steps of
// 1/64, beta components in [-0.25, 0.25] in steps of 1/128.
let alpha_distribution = Uniform::new_inclusive(-16, 16).unwrap();
let beta_distribution = Uniform::new_inclusive(-32, 32).unwrap();
let alpha_divisor: f32 = 64.0;
let beta_divisor: f32 = 128.0;
// Codes are sampled uniformly over the full inclusive domain of the `NBITS`-bit
// unsigned representation.
let domain = Unsigned::domain_const::<NBITS>();
let code_distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();
// Scratch buffers, allocated once and refilled on every trial.
let mut beta: Vec<f32> = vec![0.0; dim];
let mut x_prime: Vec<u8> = vec![0; dim];
let mut y_prime: Vec<u8> = vec![0; dim];
let mut x_reconstructed: Vec<f32> = vec![0.0; dim];
let mut y_reconstructed: Vec<f32> = vec![0.0; dim];
let mut x_compensated = CompensatedVector::<NBITS>::new_boxed(dim);
let mut y_compensated = CompensatedVector::<NBITS>::new_boxed(dim);
// Writes `codes` into `dst` and sets its metadata to
// `alpha * sum(code_i * beta_i) / bit_scale::<NBITS>()`.
// `bit_scale` here is the function imported from `crate::scalar`.
let populate_compensation = |mut dst: MutCompensatedVectorRef<'_, NBITS>,
codes: &[u8],
alpha: f32,
beta: &[f32]| {
assert_eq!(dst.len(), codes.len());
assert_eq!(dst.len(), beta.len());
let mut compensation: f32 = 0.0;
let mut vector = dst.vector_mut();
for (i, (&c, &b)) in std::iter::zip(codes.iter(), beta.iter()).enumerate() {
vector.set(i, c.into()).unwrap();
let c: f32 = c.into();
compensation += c * b;
}
dst.set_meta(Compensation(alpha * compensation / bit_scale::<NBITS>()));
};
for trial in 0..ntrials {
// Draw order matters for reproducibility with a seeded RNG: alpha, then beta,
// then the codes of x, then the codes of y.
let alpha = (alpha_distribution.sample(rng) as f32) / alpha_divisor;
beta.iter_mut().for_each(|b| {
*b = (beta_distribution.sample(rng) as f32) / beta_divisor;
});
x_prime
.iter_mut()
.for_each(|x| *x = code_distribution.sample(rng).try_into().unwrap());
y_prime
.iter_mut()
.for_each(|y| *y = code_distribution.sample(rng).try_into().unwrap());
// NOTE: from here to the end of the loop body this local shadows the imported
// `bit_scale` function, and it holds the value of `inverse_bit_scale`.
let bit_scale = inverse_bit_scale::<NBITS>();
// Full-precision reconstruction: value = alpha * code * inverse_bit_scale + beta.
x_reconstructed
.iter_mut()
.zip(x_prime.iter())
.zip(beta.iter())
.for_each(|((x, xp), b)| {
*x = (alpha * *xp as f32) * bit_scale + *b;
});
y_reconstructed
.iter_mut()
.zip(y_prime.iter())
.zip(beta.iter())
.for_each(|((y, yp), b)| {
*y = (alpha * *yp as f32) * bit_scale + *b;
});
populate_compensation(x_compensated.reborrow_mut(), &x_prime, alpha, &beta);
populate_compensation(y_compensated.reborrow_mut(), &y_prime, alpha, &beta);
// ---- Squared L2 ----
let expected: MV<f32> =
diskann_vector::distance::SquaredL2::evaluate(&*x_reconstructed, &*y_reconstructed);
let distance = CompensatedSquaredL2::new(alpha * alpha);
let got: distances::MathematicalResult<f32> =
distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
let got = got.unwrap();
let relative_err =
test_util::compute_relative_error(got.into_inner(), expected.into_inner());
let absolute_err =
test_util::compute_absolute_error(got.into_inner(), expected.into_inner());
assert!(
relative_err <= max_relative_err_l2 || absolute_err <= max_absolute_error,
"failed SquaredL2 for NBITS = {}, dim = {}, trial = {}. \
Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
Expected {}, got {}",
NBITS,
dim,
trial,
relative_err,
absolute_err,
max_relative_err_l2,
max_absolute_error,
expected.into_inner(),
got.into_inner(),
);
// The plain-`f32` flavour must return exactly the same value.
let got_f32: distances::Result<f32> =
distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
let got_f32 = got_f32.unwrap();
assert_eq!(got.into_inner(), got_f32);
// ---- Inner product ----
let expected: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
&*x_reconstructed,
&*y_reconstructed,
);
// The shift term is the sum of squares of `beta`.
let distance =
CompensatedIP::new(alpha * alpha, beta.iter().map(|&i| i * i).sum::<f32>());
let got: distances::MathematicalResult<f32> =
distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
let got = got.unwrap();
let relative_err =
test_util::compute_relative_error(got.into_inner(), expected.into_inner());
let absolute_err =
test_util::compute_absolute_error(got.into_inner(), expected.into_inner());
assert!(
relative_err <= max_relative_err_ip || absolute_err < max_absolute_error,
"failed InnerProduct for NBITS = {}, dim = {}, trial = {}. \
Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
Expected {}, got {}",
NBITS,
dim,
trial,
relative_err,
absolute_err,
max_relative_err_ip,
max_absolute_error,
expected.into_inner(),
got.into_inner(),
);
// The plain-`f32` flavour must return exactly the negated value.
let got_f32: distances::Result<f32> =
distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
let got_f32 = got_f32.unwrap();
assert_eq!(-got.into_inner(), got_f32);
// ---- Normalized cosine ----
// The reference is derived from the full-precision squared L2 as `1 - d / 2`.
// The reconstructed vectors are not normalized anywhere in this test, so this
// validates the formula rather than a true cosine similarity.
let expected: MV<f32> =
diskann_vector::distance::SquaredL2::evaluate(&*x_reconstructed, &*y_reconstructed);
let expected = 1.0 - expected.into_inner() / 2.0;
let distance = CompensatedCosineNormalized::new(alpha * alpha);
let got: distances::MathematicalResult<f32> =
distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
let got = got.unwrap();
// When the reference is exactly zero, only the absolute error is checked.
if expected != 0.0 {
let relative_err = test_util::compute_relative_error(got.into_inner(), expected);
let absolute_err = test_util::compute_absolute_error(got.into_inner(), expected);
assert!(
relative_err < max_relative_err_cos || absolute_err < max_absolute_error,
"failed CosineNormalized for NBITS = {}, dim = {}, trial = {}. \
Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
Expected {}, got {}",
NBITS,
dim,
trial,
relative_err,
absolute_err,
max_relative_err_cos,
max_absolute_error,
expected,
got.into_inner(),
);
} else {
let absolute_err = test_util::compute_absolute_error(got.into_inner(), expected);
assert!(
absolute_err < max_absolute_error,
"failed CosineNormalized for NBITS = {}, dim = {}, trial = {}. \
Got an absolute error {} with tolerance {}. \
Expected {}, got {}",
NBITS,
dim,
trial,
absolute_err,
max_absolute_error,
expected,
got.into_inner(),
);
}
// The plain-`f32` flavour must return exactly `1 - similarity`.
let got_f32: distances::Result<f32> =
distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
let got_f32 = got_f32.unwrap();
assert_eq!(1.0 - got.into_inner(), got_f32);
}
}
// Problem sizes for the generated tests: dimensions `0..MAX_DIM` are swept with
// `TRIALS_PER_DIM` random trials each. Under Miri the sweep is reduced to 37
// dimensions and a single trial (presumably to keep interpreted runs short).
cfg_if::cfg_if! {
if #[cfg(miri)] {
const MAX_DIM: usize = 37;
const TRIALS_PER_DIM: usize = 1;
} else {
const MAX_DIM: usize = 256;
const TRIALS_PER_DIM: usize = 20;
}
}
/// Generates a `#[test]` named `$name` that runs `test_compensated_distance` for
/// `$nbits`-bit codes over every dimension in `0..MAX_DIM`.
///
/// Arguments, in order: test name, bit width, relative tolerances for squared L2,
/// inner product and normalized cosine, and the RNG seed. One seeded RNG is shared
/// across all dimensions, so the sequence of inputs is reproducible.
macro_rules! test_unsigned_compensated {
    (
        $name:ident,
        $nbits:literal,
        $relative_err_l2:literal,
        $relative_err_ip:literal,
        $relative_err_cos:literal,
        $seed:literal
    ) => {
        #[test]
        fn $name() {
            // Absolute tolerance shared by all three distance checks.
            const ABSOLUTE_TOLERANCE: f32 = 2.0e-7;

            let mut generator = StdRng::seed_from_u64($seed);
            (0..MAX_DIM).for_each(|dim| {
                test_compensated_distance::<$nbits, _>(
                    dim,
                    TRIALS_PER_DIM,
                    $relative_err_l2,
                    $relative_err_ip,
                    $relative_err_cos,
                    ABSOLUTE_TOLERANCE,
                    &mut generator,
                );
            });
        }
    };
}
// One generated test per bit width. Arguments after the name and bit width are the
// relative tolerances for squared L2, inner product and normalized cosine, followed
// by the RNG seed.

// 8 bits: the squared-L2 tolerance (4.0e-4) is the loosest of all widths.
test_unsigned_compensated!(
unsigned_compensated_distances_8bit,
8,
4.0e-4,
3.0e-6,
1.0e-3,
0xa32d5658097a1c35
);
// 7 bits.
test_unsigned_compensated!(
unsigned_compensated_distances_7bit,
7,
5.0e-6,
3.0e-6,
1.0e-3,
0x0b65ca44ec7b47d8
);
// 6 bits.
test_unsigned_compensated!(
unsigned_compensated_distances_6bit,
6,
5.0e-6,
3.0e-6,
1.0e-3,
0x471b640fba5c520b
);
// 5 bits.
test_unsigned_compensated!(
unsigned_compensated_distances_5bit,
5,
5.0e-6,
3.0e-6,
1.0e-3,
0xf60c0c8d1aadc126
);
// 4 bits.
test_unsigned_compensated!(
unsigned_compensated_distances_4bit,
4,
3.0e-6,
3.0e-6,
1.0e-3,
0xcc2b897373a143f3
);
// 3 bits.
test_unsigned_compensated!(
unsigned_compensated_distances_3bit,
3,
3.0e-6,
3.0e-6,
1.0e-3,
0xaedf3d2a223b7b77
);
// 2 bits.
test_unsigned_compensated!(
unsigned_compensated_distances_2bit,
2,
3.0e-6,
3.0e-6,
1.0e-3,
0x2b34015910b34083
);
// 1 bit: all relative tolerances are 0.0, so a trial passes only when the result is
// exact or within the absolute tolerance set inside the macro.
test_unsigned_compensated!(
unsigned_compensated_distances_1bit,
1,
0.0,
0.0,
0.0,
0x09fa14c42a9d7d98
);
}