use diskann_vector::{MathematicalValue, PureDistanceFunction};
use thiserror::Error;
use crate::{
alloc::GlobalAllocator,
bits::{BitSlice, Dense, Representation, Unsigned},
distances,
distances::{InnerProduct, MV},
meta::{self, slice},
};
/// Per-vector compensation terms stored alongside min-max quantized codes.
///
/// A code `c_i` is reconstructed as `c_i * a + b` (this is what
/// `decompress_into` computes). The remaining fields cache quantities that the
/// distance kernels in this module consume, so they never re-scan the codes:
///
/// * `n` is used by the inner-product kernel as `a * Σ c_i`.
/// * `norm_squared` is used by the L2 and cosine distances as the squared L2
///   norm of the reconstructed vector.
///
/// The struct is `#[repr(C)]` and `Pod`, so the declaration order below is part
/// of its memory layout. `dim` is first, and `read_dimension` decodes the first
/// four bytes of a buffer as `dim`.
#[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct MinMaxCompensation {
    /// Number of dimensions in the encoded vector.
    pub dim: u32,
    /// Additive offset applied to every decoded code.
    pub b: f32,
    /// Expected to hold `a * Σ c_i` (scale times the sum of all codes).
    pub n: f32,
    /// Multiplicative scale applied to every decoded code.
    pub a: f32,
    /// Squared L2 norm of the reconstructed (decoded) vector.
    pub norm_squared: f32,
}

/// Size of [`MinMaxCompensation`] in bytes; the minimum buffer length accepted
/// by [`MinMaxCompensation::read_dimension`].
const META_BYTES: usize = std::mem::size_of::<MinMaxCompensation>();
/// Failure to interpret a byte buffer as holding a [`MinMaxCompensation`] header.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum MetaParseError {
    /// The buffer is shorter than `META_BYTES`; the payload is its actual length.
    #[error("Invalid size: {0}, must contain at least {META_BYTES} bytes")]
    NotCanonical(usize),
}
impl MinMaxCompensation {
    /// Reads only the `dim` field from the front of `bytes`, without decoding
    /// the rest of the header.
    ///
    /// The first four bytes are decoded as a little-endian `u32` and widened to
    /// `usize`.
    ///
    /// # Errors
    ///
    /// Returns [`MetaParseError::NotCanonical`] carrying `bytes.len()` when the
    /// buffer is shorter than `META_BYTES`.
    #[inline(always)]
    pub fn read_dimension(bytes: &[u8]) -> Result<usize, MetaParseError> {
        let too_short = || MetaParseError::NotCanonical(bytes.len());
        if bytes.len() < META_BYTES {
            return Err(too_short());
        }
        // `dim` is the first field of the `#[repr(C)]` header, i.e. bytes 0..4.
        // NOTE(review): the header type is `Pod` (native layout) but this reads
        // little-endian; confirm the two agree on big-endian targets.
        let dim_bytes: [u8; 4] = bytes
            .get(..4)
            .and_then(|prefix| prefix.try_into().ok())
            .ok_or_else(too_short)?;
        Ok(u32::from_le_bytes(dim_bytes) as usize)
    }
}
/// Min-max quantized vector container: densely packed `NBITS`-bit unsigned codes
/// plus a [`MinMaxCompensation`] header (see `meta::Vector`).
pub type Data<const NBITS: usize> =
    meta::Vector<NBITS, Unsigned, MinMaxCompensation, Dense>;

/// Borrowed, read-only view of a min-max quantized vector (see `meta::VectorRef`).
pub type DataRef<'a, const NBITS: usize> =
    meta::VectorRef<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
/// Errors produced while decoding a quantized vector back to `f32`.
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
pub enum DecompressError {
    /// Source and destination lengths differ: `(source_len, destination_len)`.
    #[error("expected src and dst length to be identical, instead src is {0}, and dst is {1}")]
    LengthMismatch(usize, usize),
}
impl<const NBITS: usize> DataRef<'_, NBITS>
where
    Unsigned: Representation<NBITS>,
{
    /// Reconstructs the full-precision vector into `dst`.
    ///
    /// Each element is decoded as `code * a + b`, using the scale `a` and offset
    /// `b` from this vector's [`MinMaxCompensation`] header.
    ///
    /// # Errors
    ///
    /// Returns [`DecompressError::LengthMismatch`] (source length first,
    /// destination length second) when `dst.len() != self.len()`. `dst` is not
    /// modified in that case.
    pub fn decompress_into(&self, dst: &mut [f32]) -> Result<(), DecompressError> {
        if dst.len() != self.len() {
            return Err(DecompressError::LengthMismatch(self.len(), dst.len()));
        }
        let meta = self.meta();
        let codes = self.vector();
        for (i, d) in dst.iter_mut().enumerate() {
            // SAFETY: `i < dst.len()`, and `dst.len() == self.len()` was checked
            // above. This relies on `self.len()` being the number of codes in
            // `codes` (the length of the underlying bit slice).
            let code = unsafe { codes.get_unchecked(i) };
            *d = code as f32 * meta.a + meta.b;
        }
        Ok(())
    }
}
/// Mutable view of a min-max quantized vector, allowing both the codes and the
/// [`MinMaxCompensation`] header to be written (see `meta::VectorMut`).
pub type DataMutRef<'a, const NBITS: usize> =
    meta::VectorMut<'a, NBITS, Unsigned, MinMaxCompensation, Dense>;
/// Metadata carried by a full-precision (`f32`) query for the asymmetric
/// query × quantized-data distances below.
#[derive(Debug, Clone, Copy, Default, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct FullQueryMeta {
    /// Sum of the query's elements; multiplied by the data offset `b` when
    /// computing the inner product.
    pub sum: f32,
    /// Squared L2 norm of the query, used by the L2 and cosine distances.
    pub norm_squared: f32,
}

/// Full-precision query storage with [`FullQueryMeta`], generic over allocator.
pub type FullQuery<A = GlobalAllocator> = slice::PolySlice<f32, FullQueryMeta, A>;

/// Read-only view of a full-precision query.
pub type FullQueryRef<'a> = slice::SliceRef<'a, f32, FullQueryMeta>;

/// Mutable view of a full-precision query.
pub type FullQueryMut<'a> = slice::SliceMut<'a, f32, FullQueryMeta>;
/// Shared inner-product core for two min-max quantized vectors.
///
/// With `x_i = a_x * cx_i + b_x` and `y_i = a_y * cy_i + b_y`, the inner product
/// expands to
///
/// ```text
/// <x, y> = a_x * a_y * Σ cx_i * cy_i   (integer code dot product)
///        + n_x * b_y                   (n_x = a_x * Σ cx_i, cached)
///        + n_y * b_x                   (n_y = a_y * Σ cy_i, cached)
///        + b_x * b_y * dim
/// ```
///
/// Only the first term needs a pass over the codes. The expanded value is
/// handed to `f` together with both headers so each caller can post-process it
/// (identity for IP, norm combination for L2, ...).
///
/// Any error from the underlying integer `InnerProduct` is propagated unchanged.
#[inline(always)]
fn kernel<const NBITS: usize, const MBITS: usize, F>(
    x: DataRef<'_, NBITS>,
    y: DataRef<'_, MBITS>,
    f: F,
) -> distances::MathematicalResult<f32>
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, MBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
    F: Fn(f32, &MinMaxCompensation, &MinMaxCompensation) -> f32,
{
    let code_dot = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner() as f32;
    let xm = x.meta();
    let ym = y.meta();

    // Accumulate strictly left to right so the f32 result is reproducible.
    let mut expanded = xm.a * ym.a * code_dot;
    expanded += xm.n * ym.b;
    expanded += ym.n * xm.b;
    expanded += xm.b * ym.b * (x.len() as f32);

    Ok(MV::new(f(expanded, &xm, &ym)))
}
/// Inner product between min-max quantized vectors, or between a
/// full-precision query and a min-max quantized vector.
///
/// The `MathematicalResult` impls return the inner product itself. The
/// `distances::Result` impls return its negation, so smaller means closer.
pub struct MinMaxIP;

/// Quantized × quantized inner product.
impl<const NBITS: usize, const MBITS: usize>
    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::MathematicalResult<f32>>
    for MinMaxIP
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, MBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
{
    #[inline(always)]
    fn evaluate(
        x: DataRef<'_, NBITS>,
        y: DataRef<'_, MBITS>,
    ) -> distances::MathematicalResult<f32> {
        // The expanded inner product from `kernel` is already the answer.
        kernel(x, y, |expanded, _, _| expanded)
    }
}

/// Quantized × quantized inner product as a distance (negated similarity).
impl<const NBITS: usize, const MBITS: usize>
    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
    for MinMaxIP
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, MBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
        // The annotation selects the `MathematicalResult` impl above.
        let similarity: distances::MathematicalResult<f32> = Self::evaluate(x, y);
        let similarity = similarity?.into_inner();
        Ok(-similarity)
    }
}

/// Full-precision query × quantized inner product:
/// `<q, a*c + b> = a * <q, c> + b * Σ q`.
impl<const NBITS: usize>
    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
    for MinMaxIP
where
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            &'a [f32],
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<f32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
        let query_dot_codes: f32 = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner();
        let data_meta = y.meta();
        Ok(MathematicalValue::new(
            query_dot_codes * data_meta.a + x.meta().sum * data_meta.b,
        ))
    }
}

/// Full-precision query × quantized inner product as a distance (negated).
impl<const NBITS: usize>
    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>> for MinMaxIP
where
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            &'a [f32],
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<f32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
        // The annotation selects the `MathematicalResult` impl above.
        let similarity: distances::MathematicalResult<f32> = Self::evaluate(x, y);
        let similarity = similarity?.into_inner();
        Ok(-similarity)
    }
}
/// Squared Euclidean distance for min-max quantized vectors, either against
/// another quantized vector or against a full-precision query.
///
/// Uses `||x - y||² = ||x||² + ||y||² - 2 <x, y>` with the cached
/// `norm_squared` values, so only the inner product touches the codes.
pub struct MinMaxL2Squared;

/// Quantized × quantized squared L2.
impl<const NBITS: usize, const MBITS: usize>
    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::MathematicalResult<f32>>
    for MinMaxL2Squared
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, MBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
{
    #[inline(always)]
    fn evaluate(
        x: DataRef<'_, NBITS>,
        y: DataRef<'_, MBITS>,
    ) -> distances::MathematicalResult<f32> {
        kernel(x, y, |ip, xm, ym| {
            let partial = -2.0 * ip + xm.norm_squared;
            partial + ym.norm_squared
        })
    }
}

/// Quantized × quantized squared L2 as a plain distance (already "smaller is closer").
impl<const NBITS: usize, const MBITS: usize>
    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
    for MinMaxL2Squared
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, MBITS, Unsigned>,
            distances::MathematicalResult<u32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
        let squared: distances::MathematicalResult<f32> = Self::evaluate(x, y);
        let squared = squared?;
        Ok(squared.into_inner())
    }
}

/// Full-precision query × quantized squared L2.
impl<const NBITS: usize>
    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::MathematicalResult<f32>>
    for MinMaxL2Squared
where
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            &'a [f32],
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<f32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::MathematicalResult<f32> {
        let query_dot_codes: f32 = InnerProduct::evaluate(x.vector(), y.vector())?.into_inner();
        let data_meta = y.meta();
        let query_meta = x.meta();
        // <q, a*c + b> = a * <q, c> + b * Σ q
        let ip = query_dot_codes * data_meta.a + query_meta.sum * data_meta.b;
        Ok(MV::new(
            query_meta.norm_squared + data_meta.norm_squared - 2.0 * ip,
        ))
    }
}

/// Full-precision query × quantized squared L2 as a plain distance.
impl<const NBITS: usize>
    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
    for MinMaxL2Squared
where
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            &'a [f32],
            BitSlice<'b, NBITS, Unsigned>,
            distances::MathematicalResult<f32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
        let squared: distances::MathematicalResult<f32> = Self::evaluate(x, y);
        let squared = squared?;
        Ok(squared.into_inner())
    }
}
/// Cosine distance `1 - <x, y> / (||x|| * ||y||)` for min-max quantized data.
///
/// The inner product comes from [`MinMaxIP`] and the norms from the cached
/// `norm_squared` metadata. There is no guard for zero norms: if either
/// `norm_squared` is `0.0`, the division yields a non-finite value.
pub struct MinMaxCosine;

/// Quantized × quantized cosine distance.
impl<const NBITS: usize, const MBITS: usize>
    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
    for MinMaxCosine
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    MinMaxIP: for<'a, 'b> PureDistanceFunction<
            DataRef<'a, NBITS>,
            DataRef<'b, MBITS>,
            distances::MathematicalResult<f32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
        let ip: MV<f32> = MinMaxIP::evaluate(x, y)?;
        let norm_product = x.meta().norm_squared.sqrt() * y.meta().norm_squared.sqrt();
        Ok(1.0 - ip.into_inner() / norm_product)
    }
}

/// Full-precision query × quantized cosine distance.
impl<const NBITS: usize>
    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
    for MinMaxCosine
where
    Unsigned: Representation<NBITS>,
    MinMaxIP: for<'a, 'b> PureDistanceFunction<
            FullQueryRef<'a>,
            DataRef<'b, NBITS>,
            distances::MathematicalResult<f32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
        let query_norm = x.meta().norm_squared.sqrt();
        let data_norm = y.meta().norm_squared.sqrt();
        Ok(1.0 - ip.into_inner() / (query_norm * data_norm))
    }
}
/// Cosine distance for inputs assumed to be unit-normalized: `1 - <x, y>`.
///
/// The stored norms are neither consulted nor checked. If the reconstructed
/// vectors are not unit length, the result is not a true cosine distance.
pub struct MinMaxCosineNormalized;

/// Quantized × quantized normalized-cosine distance.
impl<const NBITS: usize, const MBITS: usize>
    PureDistanceFunction<DataRef<'_, NBITS>, DataRef<'_, MBITS>, distances::Result<f32>>
    for MinMaxCosineNormalized
where
    Unsigned: Representation<NBITS> + Representation<MBITS>,
    MinMaxIP: for<'a, 'b> PureDistanceFunction<
            DataRef<'a, NBITS>,
            DataRef<'b, MBITS>,
            distances::MathematicalResult<f32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: DataRef<'_, NBITS>, y: DataRef<'_, MBITS>) -> distances::Result<f32> {
        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
        let similarity = ip.into_inner();
        Ok(1.0 - similarity)
    }
}

/// Full-precision query × quantized normalized-cosine distance.
impl<const NBITS: usize>
    PureDistanceFunction<FullQueryRef<'_>, DataRef<'_, NBITS>, distances::Result<f32>>
    for MinMaxCosineNormalized
where
    Unsigned: Representation<NBITS>,
    MinMaxIP: for<'a, 'b> PureDistanceFunction<
            FullQueryRef<'a>,
            DataRef<'b, NBITS>,
            distances::MathematicalResult<f32>,
        >,
{
    #[inline(always)]
    fn evaluate(x: FullQueryRef<'_>, y: DataRef<'_, NBITS>) -> distances::Result<f32> {
        let ip: MathematicalValue<f32> = MinMaxIP::evaluate(x, y)?;
        let similarity = ip.into_inner();
        Ok(1.0 - similarity)
    }
}
#[cfg(test)]
#[cfg(not(miri))]
mod minmax_vector_tests {
    use diskann_utils::Reborrow;
    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{alloc::GlobalAllocator, scalar::bit_scale};

    /// Draws a random `NBITS`-bit min-max vector with scale `a` and offset `b`
    /// in `[0, 2]`. Its metadata is filled in directly from the codes, and it is
    /// returned with its exact reconstruction `a * code + b`.
    fn random_minmax_vector<const NBITS: usize>(
        dim: usize,
        rng: &mut impl Rng,
    ) -> (Data<NBITS>, Vec<f32>)
    where
        Unsigned: Representation<NBITS>,
    {
        let mut v = Data::<NBITS>::new_boxed(dim);
        let domain = Unsigned::domain_const::<NBITS>();
        let code_dist = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();
        {
            let mut bs = v.vector_mut();
            for i in 0..dim {
                bs.set(i, code_dist.sample(rng)).unwrap();
            }
        }
        let a: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);
        let b: f32 = Uniform::new_inclusive(0.0, 2.0).unwrap().sample(rng);
        let original: Vec<f32> = (0..dim)
            .map(|i| a * v.vector().get(i).unwrap() as f32 + b)
            .collect();
        let code_sum: f32 = (0..dim).map(|i| v.vector().get(i).unwrap() as f32).sum();
        let norm_squared: f32 = original.iter().map(|x| x * x).sum();
        v.set_meta(MinMaxCompensation {
            a,
            b,
            n: a * code_sum,
            norm_squared,
            dim: dim as u32,
        });
        (v, original)
    }

    /// Checks every distance in this module, for both quantized × quantized and
    /// full-query × quantized inputs, against brute force on the reconstructed
    /// vectors. Also exercises `decompress_into` and `read_dimension`.
    fn test_minmax_compensated_vectors<const NBITS: usize, R>(dim: usize, rng: &mut R)
    where
        Unsigned: Representation<NBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                BitSlice<'a, NBITS, Unsigned>,
                BitSlice<'b, NBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                &'a [f32],
                BitSlice<'b, NBITS, Unsigned>,
                distances::MathematicalResult<f32>,
            >,
        R: Rng,
    {
        assert!(dim <= bit_scale::<NBITS>() as usize);
        let (v1, original1) = random_minmax_vector::<NBITS>(dim, rng);
        let (v2, original2) = random_minmax_vector::<NBITS>(dim, rng);
        let norm1_squared = v1.meta().norm_squared;
        let norm2_squared = v2.meta().norm_squared;
        let expected_ip = (0..dim).map(|i| original1[i] * original2[i]).sum::<f32>();

        // ---- quantized x quantized ----
        let computed_ip_f32: distances::Result<f32> =
            MinMaxIP::evaluate(v1.reborrow(), v2.reborrow());
        let computed_ip_f32 = computed_ip_f32.unwrap();
        assert!(
            (expected_ip - (-computed_ip_f32)).abs() / expected_ip.abs() < 1e-3,
            "Inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            computed_ip_f32,
            dim
        );

        let expected_l2 = (0..dim)
            .map(|i| original1[i] - original2[i])
            .map(|x| x.powf(2.0))
            .sum::<f32>();
        let computed_l2_f32: distances::Result<f32> =
            MinMaxL2Squared::evaluate(v1.reborrow(), v2.reborrow());
        let computed_l2_f32 = computed_l2_f32.unwrap();
        assert!(
            ((computed_l2_f32 - expected_l2).abs() / expected_l2) < 1e-3,
            "L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            computed_l2_f32,
            dim
        );

        let expected_cosine = 1.0 - expected_ip / (norm1_squared.sqrt() * norm2_squared.sqrt());
        let computed_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(v1.reborrow(), v2.reborrow());
        let computed_cosine = computed_cosine.unwrap();
        {
            let passed = (computed_cosine - expected_cosine).abs() < 1e-6
                || ((computed_cosine - expected_cosine).abs() / expected_cosine) < 1e-3;
            assert!(
                passed,
                "Cosine distance (f32) failed: expected {}, got {} on dim : {}",
                expected_cosine, computed_cosine, dim
            );
        }

        let cosine_normalized: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(v1.reborrow(), v2.reborrow());
        let cosine_normalized = cosine_normalized.unwrap();
        let expected_cos_normalized = 1.0 - expected_ip;
        assert!(
            ((expected_cos_normalized - cosine_normalized).abs() / expected_cos_normalized.abs())
                < 1e-6,
            "CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cos_normalized,
            cosine_normalized,
            dim
        );

        // ---- full-precision query x quantized ----
        let mut fp_query = FullQuery::new_in(dim, GlobalAllocator).unwrap();
        fp_query.vector_mut().copy_from_slice(&original1);
        *fp_query.meta_mut() = FullQueryMeta {
            norm_squared: norm1_squared,
            sum: original1.iter().sum::<f32>(),
        };

        let fp_ip: distances::Result<f32> = MinMaxIP::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_ip = fp_ip.unwrap();
        assert!(
            (expected_ip - (-fp_ip)).abs() / expected_ip.abs() < 1e-3,
            "Full-query inner product (f32) failed: expected {}, got {} on dim : {}",
            -expected_ip,
            fp_ip,
            dim
        );

        let fp_l2: distances::Result<f32> =
            MinMaxL2Squared::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_l2 = fp_l2.unwrap();
        assert!(
            ((fp_l2 - expected_l2).abs() / expected_l2) < 1e-3,
            "Full-query L2 distance (f32) failed: expected {}, got {} on dim : {}",
            expected_l2,
            fp_l2,
            dim
        );

        let fp_cosine: distances::Result<f32> =
            MinMaxCosine::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cosine = fp_cosine.unwrap();
        let diff = (fp_cosine - expected_cosine).abs();
        assert!(
            (diff / expected_cosine) < 1e-3 || diff <= 1e-6,
            "Full-query cosine distance (f32) failed: expected {}, got {} on dim : {}",
            expected_cosine,
            fp_cosine,
            dim
        );

        let fp_cos_norm: distances::Result<f32> =
            MinMaxCosineNormalized::evaluate(fp_query.reborrow(), v2.reborrow());
        let fp_cos_norm = fp_cos_norm.unwrap();
        assert!(
            (((1.0 - expected_ip) - fp_cos_norm).abs() / (1.0 - expected_ip)) < 1e-3,
            "Full-query CosineNormalized distance (f32) failed: expected {}, got {} on dim : {}",
            (1.0 - expected_ip),
            fp_cos_norm,
            dim
        );

        // ---- decompression ----
        let meta = v1.meta();
        let v1_ref = DataRef::new(v1.vector(), &meta);
        let dim = v1_ref.len();
        let mut boxed = vec![0f32; dim + 1];
        let pre = v1_ref.decompress_into(&mut boxed);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim + 1)
        );
        let pre = v1_ref.decompress_into(&mut boxed[..dim - 1]);
        assert_eq!(
            pre.unwrap_err(),
            DecompressError::LengthMismatch(dim, dim - 1)
        );
        v1_ref.decompress_into(&mut boxed[..dim]).unwrap();
        boxed[..dim]
            .iter()
            .zip(original1.iter())
            .for_each(|(x, y)| assert!((*x - *y).abs() <= 1e-6));

        // ---- header parsing ----
        let mut bytes = vec![0u8; Data::canonical_bytes(dim)];
        let mut data = DataMutRef::from_canonical_front_mut(bytes.as_mut_slice(), dim).unwrap();
        data.set_meta(meta);
        assert_eq!(MinMaxCompensation::read_dimension(&bytes).unwrap(), dim);
        assert_eq!(
            MinMaxCompensation::read_dimension(&[0_u8; 2]).unwrap_err(),
            MetaParseError::NotCanonical(2)
        );
        // One byte short of a full header must also be rejected.
        assert_eq!(
            MinMaxCompensation::read_dimension(&[0_u8; META_BYTES - 1]).unwrap_err(),
            MetaParseError::NotCanonical(META_BYTES - 1)
        );
    }

    // This module is already gated on `not(miri)`, so the `miri` arm is
    // currently unreachable; it is kept so the trial count stays small if that
    // gate is ever lifted.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const TRIALS: usize = 2;
        } else {
            const TRIALS: usize = 10;
        }
    }

    /// Generates a seeded test running `test_minmax_compensated_vectors` for
    /// every dimension in `1..=bit_scale::<NBITS>()`.
    macro_rules! test_minmax_compensated {
        ($name:ident, $nbits:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                const MAX_DIM: usize = bit_scale::<$nbits>() as usize;
                for dim in 1..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_minmax_compensated_vectors::<$nbits, _>(dim, &mut rng);
                    }
                }
            }
        };
    }

    test_minmax_compensated!(unsigned_minmax_compensated_test_u1, 1, 0xa33d5658097a1c35);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u2, 2, 0xaedf3d2a223b7b77);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u4, 4, 0xf60c0c8d1aadc126);
    test_minmax_compensated!(unsigned_minmax_compensated_test_u8, 8, 0x09fa14c42a9d7d98);

    /// Checks `kernel` on mixed bit widths (`NBITS` query codes vs `MBITS` data
    /// codes) against the brute-force inner product of the reconstructions.
    fn test_minmax_heterogeneous_kernel<const NBITS: usize, const MBITS: usize, R>(
        dim: usize,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS> + Representation<MBITS>,
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                BitSlice<'a, NBITS, Unsigned>,
                BitSlice<'b, MBITS, Unsigned>,
                distances::MathematicalResult<u32>,
            >,
        R: Rng,
    {
        let (v_query, original1) = random_minmax_vector::<NBITS>(dim, rng);
        let (v_data, original2) = random_minmax_vector::<MBITS>(dim, rng);
        let expected_ip: f32 = original1.iter().zip(&original2).map(|(x, y)| x * y).sum();
        let computed_ip = kernel(v_query.reborrow(), v_data.reborrow(), |v, _, _| v)
            .unwrap()
            .into_inner();
        assert!(
            (expected_ip - computed_ip).abs() / expected_ip.abs().max(1e-10) < 1e-6,
            "Heterogeneous IP ({},{}) failed: expected {}, got {} on dim: {}",
            NBITS,
            MBITS,
            expected_ip,
            computed_ip,
            dim,
        );
    }

    /// Generates a seeded heterogeneous-width kernel test over
    /// `1..=bit_scale::<M>()` dimensions.
    macro_rules! test_minmax_heterogeneous {
        ($name:ident, $N:literal, $M:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                const MAX_DIM: usize = bit_scale::<$M>() as usize;
                for dim in 1..=MAX_DIM {
                    for _ in 0..TRIALS {
                        test_minmax_heterogeneous_kernel::<$N, $M, _>(dim, &mut rng);
                    }
                }
            }
        };
    }

    test_minmax_heterogeneous!(minmax_heterogeneous_8x4, 8, 4, 0xb7c3d9e5f1a20864);
    test_minmax_heterogeneous!(minmax_heterogeneous_8x2, 8, 2, 0x4e8f2c6a1d3b5079);
    test_minmax_heterogeneous!(minmax_heterogeneous_8x1, 8, 1, 0x1b0f2c614d2a7141);
}