use diskann::{ANNError, utils::VectorRepr};
use diskann_quantization::{
bits::{BitSlice, Representation, Unsigned},
distances::InnerProduct,
meta::NotCanonical,
minmax::{
Data, DataRef, DecompressError, MetaParseError, MinMaxCompensation, MinMaxCosine,
MinMaxCosineNormalized, MinMaxIP, MinMaxL2Squared,
},
};
use diskann_vector::{PureDistanceFunction, distance::Metric};
use thiserror::Error;
/// Errors produced while interpreting a slice of [`MinMaxElement`] as MinMax-quantized data
/// or while converting it back to full precision.
#[derive(Debug, Error, Clone, PartialEq)]
pub enum MMConvertError {
/// The metadata needed to read the vector's dimension could not be parsed.
#[error("MinMax metadata cannot be parsed, with error {0}")]
MetaParseError(#[from] MetaParseError),
/// The raw bytes do not form a canonical encoding, e.g. the slice is shorter than the
/// canonical byte count for the stored dimension.
#[error("Data format is not canonical {0}")]
NotCanonical(#[from] NotCanonical),
/// Decompressing the quantized data into `f32` values failed.
#[error("Decompression failed {0}")]
Decompression(#[from] DecompressError),
/// The destination buffer does not have exactly one slot per dimension.
/// The first field is the vector's dimension, the second is the destination length.
#[error("Full-precision slice length {0} does not match destination slice length {1}.")]
WrongLength(usize, usize),
}
/// Converts a [`MMConvertError`] into an [`ANNError`].
///
/// The conversion error is rendered with its `Debug` representation and passed to
/// `ANNError::log_index_error`.
impl From<MMConvertError> for ANNError {
    fn from(err: MMConvertError) -> ANNError {
        ANNError::log_index_error(format_args!(
            "Unable to convert MinMaxElement slice, error : {:?}",
            err
        ))
    }
}
/// Single-byte storage unit for MinMax-quantized vectors.
///
/// The type is a `#[repr(transparent)]` wrapper around `u8` and is `bytemuck::Pod`, so a
/// slice of elements can be reinterpreted as raw bytes. `NBITS` is the bit-width parameter
/// forwarded to the quantization types (e.g. `DataRef<NBITS>`) when such a slice is parsed.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, bytemuck::Pod, bytemuck::Zeroable)]
#[repr(transparent)]
pub struct MinMaxElement<const NBITS: usize>(u8);
/// [`MinMaxElement`] with `NBITS = 8`.
pub type MinMax8 = MinMaxElement<8>;
/// [`MinMaxElement`] with `NBITS = 4`.
pub type MinMax4 = MinMaxElement<4>;
/// [`MinMaxElement`] with `NBITS = 2`.
// NOTE(review): the leading underscore presumably marks this alias as not yet used — confirm.
pub type _MinMax2 = MinMaxElement<2>;
/// [`MinMaxElement`] with `NBITS = 1`.
// NOTE(review): the leading underscore presumably marks this alias as not yet used — confirm.
pub type _MinMax1 = MinMaxElement<1>;
impl<const NBITS: usize> MinMaxElement<NBITS> {
    /// Largest value that fits in `NBITS` bits, i.e. `2^NBITS - 1`.
    const UPPER_RANGE: usize = (1_usize << NBITS) - 1;
}
// Generates one checked `FromPrimitive`-style constructor named `$method` taking `$input_type`.
//
// The generated function returns `Some` only when the input lies in `0..=Self::UPPER_RANGE`;
// the value is then stored as a `u8`. It must be expanded inside an `impl` where
// `Self::UPPER_RANGE` is available.
macro_rules! impl_from_primitive {
    // Unsigned inputs: only the upper bound has to be checked.
    ($NBITS:expr, $method:ident, $input_type:ty) => {
        fn $method(n: $input_type) -> Option<Self> {
            let fits = n <= Self::UPPER_RANGE as $input_type;
            fits.then_some(MinMaxElement::<$NBITS>(n as u8))
        }
    };
    // Signed inputs: negative values must be rejected as well.
    ($NBITS:expr, $method:ident, $input_type:ty, signed) => {
        fn $method(n: $input_type) -> Option<Self> {
            let accepted = 0..=Self::UPPER_RANGE as $input_type;
            accepted
                .contains(&n)
                .then_some(MinMaxElement::<$NBITS>(n as u8))
        }
    };
}
/// Checked construction from primitive integers.
///
/// Each constructor yields `Some` only when the input lies in `0..=UPPER_RANGE`.
impl<const NBITS: usize> num_traits::FromPrimitive for MinMaxElement<NBITS> {
    // The two methods `FromPrimitive` requires.
    impl_from_primitive!(NBITS, from_i64, i64, signed);
    impl_from_primitive!(NBITS, from_u64, u64);

    // Overrides for the narrower integer types.
    impl_from_primitive!(NBITS, from_i32, i32, signed);
    impl_from_primitive!(NBITS, from_u8, u8);
    impl_from_primitive!(NBITS, from_u32, u32);
}
/// Adapts the MinMax distance functor `T` to a plain function over raw element slices.
///
/// Both slices are parsed with `MinMaxElement::from_raw`, and the resulting views are handed
/// to `T::evaluate`.
///
/// # Panics
///
/// Panics if either slice cannot be parsed or if `T::evaluate` returns an error.
// The `fn(&[_], &[_]) -> f32` shape leaves no channel for reporting an error, hence `unwrap`.
#[allow(clippy::unwrap_used)]
fn as_fn_pointer_minmax<T, const NBITS: usize>(
    x: &[MinMaxElement<NBITS>],
    y: &[MinMaxElement<NBITS>],
) -> f32
where
    T: for<'a, 'b> PureDistanceFunction<
            DataRef<'a, NBITS>,
            DataRef<'b, NBITS>,
            diskann_quantization::distances::Result<f32>,
        >,
    Unsigned: Representation<NBITS>,
{
    let lhs = MinMaxElement::<NBITS>::from_raw(x).unwrap();
    let rhs = MinMaxElement::<NBITS>::from_raw(y).unwrap();
    T::evaluate(lhs, rhs).unwrap()
}
/// [`VectorRepr`] for MinMax-quantized vectors stored as raw [`MinMaxElement`] slices.
impl<const NBITS: usize> VectorRepr for MinMaxElement<NBITS>
where
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, NBITS, Unsigned>,
            BitSlice<'b, NBITS, Unsigned>,
            diskann_quantization::distances::MathematicalResult<u32>,
        >,
    Unsigned: Representation<NBITS>,
{
    type Error = MMConvertError;
    type Distance = FnPtr<Self>;
    type QueryDistance = BufferedFnPtr<Self>;

    /// Returns the pairwise distance function for `metric`; the dimension hint is unused.
    fn distance(metric: Metric, _dim: Option<usize>) -> Self::Distance {
        let comparer = Self::distance_comparer(metric);
        FnPtr::new(comparer)
    }

    /// Returns a distance function for `metric` that owns a copy of `query`.
    fn query_distance(query: &[Self], metric: Metric) -> Self::QueryDistance {
        let f = Self::distance_comparer(metric);
        let query = Box::<[Self]>::from(query);
        BufferedFnPtr { query, f }
    }

    /// Reads the dimension stored in the encoded vector.
    fn full_dimension(vec: &[Self]) -> Result<usize, Self::Error> {
        Self::extract_dimension(vec)
    }

    /// Decompresses `data` into a freshly allocated `f32` buffer with one entry per dimension.
    fn as_f32(data: &[Self]) -> Result<impl std::ops::Deref<Target = [f32]>, Self::Error> {
        let quantized = Self::from_raw(data)?;
        let mut full_precision = vec![f32::default(); quantized.meta().dim as usize];
        quantized.decompress_into(&mut full_precision)?;
        Ok(full_precision)
    }

    /// Decompresses `src` into `dst`.
    ///
    /// # Errors
    ///
    /// Returns [`MMConvertError::WrongLength`] when `dst.len()` differs from the encoded
    /// dimension, and forwards any parse or decompression error.
    fn as_f32_into(src: &[Self], dst: &mut [f32]) -> Result<(), Self::Error> {
        let quantized = Self::from_raw(src)?;
        let dim = quantized.meta().dim as usize;
        if dst.len() == dim {
            quantized.decompress_into(dst)?;
            Ok(())
        } else {
            Err(MMConvertError::WrongLength(dim, dst.len()))
        }
    }
}
impl<const NBITS: usize> MinMaxElement<NBITS>
where
    Unsigned: Representation<NBITS>,
{
    /// Parses `raw` as a canonical MinMax vector and returns a borrowed view of it.
    ///
    /// Only the first `Data::<NBITS>::canonical_bytes(dim)` bytes are used; anything after
    /// them is ignored.
    ///
    /// # Errors
    ///
    /// Fails when the dimension cannot be read, when `raw` is shorter than the canonical
    /// byte count, or when `DataRef::from_canonical_front` rejects the bytes.
    fn from_raw(raw: &[Self]) -> Result<DataRef<'_, NBITS>, MMConvertError> {
        let dim = Self::extract_dimension(raw)?;
        let required = Data::<NBITS>::canonical_bytes(dim);
        let bytes = bytemuck::cast_slice::<MinMaxElement<NBITS>, u8>(raw);
        // Each element is exactly one byte, so `raw.len()` equals `bytes.len()`.
        match bytes.get(..required) {
            Some(front) => Ok(DataRef::<'_, NBITS>::from_canonical_front(front, dim)?),
            None => Err(MMConvertError::NotCanonical(
                diskann_quantization::meta::NotCanonical::WrongLength(raw.len(), required),
            )),
        }
    }

    /// Reads the vector dimension from the bytes of `raw`.
    fn extract_dimension(raw: &[Self]) -> Result<usize, MMConvertError> {
        MinMaxCompensation::read_dimension(bytemuck::cast_slice::<Self, u8>(raw))
            .map_err(MMConvertError::from)
    }

    /// Selects the raw-slice distance function that implements `metric`.
    fn distance_comparer(metric: Metric) -> fn(&[Self], &[Self]) -> f32
    where
        InnerProduct: for<'a, 'b> PureDistanceFunction<
                BitSlice<'a, NBITS, Unsigned>,
                BitSlice<'b, NBITS, Unsigned>,
                diskann_quantization::distances::MathematicalResult<u32>,
            >,
    {
        match metric {
            Metric::L2 => as_fn_pointer_minmax::<MinMaxL2Squared, NBITS>,
            Metric::InnerProduct => as_fn_pointer_minmax::<MinMaxIP, NBITS>,
            Metric::Cosine => as_fn_pointer_minmax::<MinMaxCosine, NBITS>,
            Metric::CosineNormalized => as_fn_pointer_minmax::<MinMaxCosineNormalized, NBITS>,
        }
    }
}
/// A distance function held as a plain function pointer over two slices of `T`.
#[derive(Debug)]
pub struct FnPtr<T>(fn(&[T], &[T]) -> f32);

impl<T> FnPtr<T> {
    /// Wraps the function pointer `f`.
    pub fn new(f: fn(&[T], &[T]) -> f32) -> Self {
        FnPtr(f)
    }
}
/// Evaluates the wrapped function pointer on the two operands, in the order given.
impl<T> diskann_vector::DistanceFunction<&[T], &[T], f32> for FnPtr<T> {
    #[inline(always)]
    fn evaluate_similarity(&self, x: &[T], y: &[T]) -> f32 {
        let compare = self.0;
        compare(x, y)
    }
}
/// A distance function paired with an owned copy of the query it is evaluated against.
#[derive(Debug)]
pub struct BufferedFnPtr<T> {
// Owned query; passed as the first argument on every evaluation.
query: Box<[T]>,
// Function invoked as `f(query, candidate)`.
f: fn(&[T], &[T]) -> f32,
}
/// Evaluates the stored function with the owned query as the first operand and `x` as the
/// second.
impl<T> diskann_vector::PreprocessedDistanceFunction<&[T]> for BufferedFnPtr<T> {
    #[inline(always)]
    fn evaluate_similarity(&self, x: &[T]) -> f32 {
        let stored_query: &[T] = &self.query;
        (self.f)(stored_query, x)
    }
}
#[cfg(test)]
mod tests {
use std::num::NonZeroUsize;
use crate::utils::create_rnd_from_seed_in_tests;
use diskann_quantization::{
CompressInto,
algorithms::{Transform, transforms::NullTransform},
minmax::{DataMutRef, DataRef, MinMaxQuantizer},
num::Positive,
};
use diskann_utils::ReborrowMut;
use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
use num_traits::FromPrimitive;
use rand::rngs::StdRng;
use rand_distr::{Distribution, Uniform};
use super::*;
// Generates a `#[test]` named `$name` that calls the const-generic function `$func` once
// for each of the bit widths 1, 2, 4 and 8.
macro_rules! expand_to_bitrates {
($name:ident, $func:ident) => {
#[test]
fn $name() {
$func::<1>();
$func::<2>();
$func::<4>();
$func::<8>();
}
};
}
/// Draws `dim` samples uniformly from the closed interval `[low, high]` using `rng`.
fn create_random_vector(dim: usize, rng: &mut StdRng, low: f32, high: f32) -> Vec<f32> {
    Uniform::new_inclusive::<f32, f32>(low, high)
        .unwrap()
        .sample_iter(rng)
        .take(dim)
        .collect()
}
/// Compresses `vector` with a null-transform `MinMaxQuantizer` and returns the canonical
/// bytes reinterpreted as `MinMaxElement<N>`.
fn minmax_compress_vector<const N: usize>(vector: &[f32]) -> Vec<MinMaxElement<N>>
where
    Unsigned: Representation<N>,
{
    let dim = vector.len();
    let quantizer = MinMaxQuantizer::new(
        Transform::Null(NullTransform::new(NonZeroUsize::new(dim).unwrap())),
        Positive::new(1.0).unwrap(),
    );

    // Backing storage sized for the canonical encoding of a `dim`-dimensional vector.
    let mut bytes = vec![0_u8; DataRef::<N>::canonical_bytes(dim)];
    let mut compressed = DataMutRef::<N>::from_canonical_front_mut(&mut bytes, dim).unwrap();
    quantizer
        .compress_into(vector, compressed.reborrow_mut())
        .unwrap();

    bytemuck::cast_slice::<u8, MinMaxElement<N>>(&bytes).to_vec()
}
// Generates a const-generic checker `$bitname` for the `FromPrimitive` constructor `$method`
// (taking `$primitive`), plus a `#[test]` named `$name` that runs it for every bit width via
// `expand_to_bitrates!`.
macro_rules! test_from_primitive {
// Unsigned primitives: no negative inputs to try.
($bitname:ident, $name:ident, $primitive:ty, $method:ident, unsigned) => {
test_from_primitive!($bitname, $name, $primitive, $method, []);
};
// Signed primitives: additionally check that negative inputs are rejected.
($bitname:ident, $name:ident, $primitive:ty, $method:ident, signed) => {
test_from_primitive!($bitname, $name, $primitive, $method, [-1, -100]);
};
// Worker arm: receives the list of negative inputs that must map to `None`.
($bitname:ident, $name:ident, $primitive:ty, $method:ident, [$($negative:expr),*]) => {
fn $bitname<const NBITS: usize>() {
let upper_range = (0x1 << NBITS) - 1;
// In-range inputs: lowest, middle, and the highest value that also fits `$primitive`.
let valid_values = [
0,
upper_range / 2, usize::min(upper_range, <$primitive>::MAX as usize), ];
for &val in &valid_values {
let result = MinMaxElement::<NBITS>::$method(val as $primitive);
assert_eq!(
result.unwrap().0,
val as u8,
"Failed for NBITS={}, val={}",
NBITS,
val
);
}
// Inputs above the representable range; each must be rejected.
let overflow_values = [
upper_range + 1,
upper_range + 100,
256, 1000,
65536,
];
// One check per negative input; expands to nothing for unsigned primitives.
$(
let result = MinMaxElement::<NBITS>::$method($negative);
assert!(
result.is_none(),
"Expected None for NBITS={}, val={}",
NBITS,
$negative
);
)*
for &val in &overflow_values {
// Skip values `$primitive` cannot hold, so the `as` cast below never truncates.
if val <= <$primitive>::MAX as usize {
let result = MinMaxElement::<NBITS>::$method(val as $primitive);
assert!(
result.is_none(),
"Expected None for NBITS={}, val={}",
NBITS,
val
);
}
}
}
expand_to_bitrates!($name, $bitname);
};
}
// One checker and one test per constructor implemented by `impl_from_primitive!`.
test_from_primitive!(from_u8_bits, from_u8, u8, from_u8, unsigned);
test_from_primitive!(from_i64_bits, from_i64, i64, from_i64, signed);
test_from_primitive!(from_u64_bits, from_u64, u64, from_u64, unsigned);
test_from_primitive!(from_i32_bits, from_i32, i32, from_i32, signed);
test_from_primitive!(from_u32_bits, from_u32, u32, from_u32, unsigned);
/// Returns a function pointer that evaluates the distance functor `T` on two `DataRef`s and
/// panics if the evaluation reports an error.
fn as_fn_pointer<T, const N: usize>() -> fn(DataRef<'_, N>, DataRef<'_, N>) -> f32
where
    T: for<'a, 'b> PureDistanceFunction<
            DataRef<'a, N>,
            DataRef<'b, N>,
            diskann_quantization::distances::Result<f32>,
        >,
    Unsigned: Representation<N>,
{
    // A nested generic `fn` rather than a closure, so it can be named as a plain fn pointer.
    fn evaluate_or_panic<D, const BITS: usize>(
        lhs: DataRef<'_, BITS>,
        rhs: DataRef<'_, BITS>,
    ) -> f32
    where
        D: for<'a, 'b> PureDistanceFunction<
                DataRef<'a, BITS>,
                DataRef<'b, BITS>,
                diskann_quantization::distances::Result<f32>,
            >,
        Unsigned: Representation<BITS>,
    {
        D::evaluate(lhs, rhs).unwrap()
    }

    evaluate_or_panic::<T, N>
}
/// Reference implementation: picks the `DataRef`-level distance function for `metric`.
fn distance_comparer<const N: usize>(
    metric: Metric,
) -> fn(DataRef<'_, N>, DataRef<'_, N>) -> f32
where
    Unsigned: Representation<N>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, N, Unsigned>,
            BitSlice<'b, N, Unsigned>,
            diskann_quantization::distances::MathematicalResult<u32>,
        >,
{
    match metric {
        Metric::L2 => as_fn_pointer::<MinMaxL2Squared, N>(),
        Metric::InnerProduct => as_fn_pointer::<MinMaxIP, N>(),
        Metric::Cosine => as_fn_pointer::<MinMaxCosine, N>(),
        Metric::CosineNormalized => as_fn_pointer::<MinMaxCosineNormalized, N>(),
    }
}
/// Checks that both the pairwise and the query-buffered distance agree with the reference
/// computed directly on `DataRef` views of the same data.
///
/// Panics if either input is not canonical for its stated dimension, or if a distance
/// evaluation fails.
fn test_distance<const N: usize>(
    (v1, d1): (&[MinMaxElement<N>], usize),
    (v2, d2): (&[MinMaxElement<N>], usize),
    metric: Metric,
) where
    Unsigned: Representation<N>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, N, Unsigned>,
            BitSlice<'b, N, Unsigned>,
            diskann_quantization::distances::MathematicalResult<u32>,
        >,
{
    const TOLERANCE: f32 = 1e-6;

    let distance = MinMaxElement::<N>::distance(metric, Some(d1));
    let query_distance = MinMaxElement::<N>::query_distance(v1, metric);

    // Reference value, computed on views built straight from the canonical bytes.
    let lhs_bytes = bytemuck::cast_slice::<MinMaxElement<N>, u8>(v1);
    let lhs = DataRef::<'_, N>::from_canonical_front(lhs_bytes, d1).unwrap();
    let rhs_bytes = bytemuck::cast_slice::<MinMaxElement<N>, u8>(v2);
    let rhs = DataRef::<'_, N>::from_canonical_front(rhs_bytes, d2).unwrap();
    let expected = distance_comparer::<N>(metric)(lhs, rhs);

    let pairwise: f32 = distance.evaluate_similarity(v1, v2);
    assert!(
        (pairwise - expected).abs() <= TOLERANCE,
        "Distance function doesn't match reference"
    );

    let buffered: f32 = query_distance.evaluate_similarity(v2);
    assert!(
        (buffered - expected).abs() <= TOLERANCE,
        "Distance function doesn't match reference"
    );
}
/// For a range of dimensions and every metric, checks that distances on random compressed
/// vectors of equal dimension match the reference.
fn test_mm_distance_fns_happy_bits<const N: usize>()
where
    Unsigned: Representation<N>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, N, Unsigned>,
            BitSlice<'b, N, Unsigned>,
            diskann_quantization::distances::MathematicalResult<u32>,
        >,
{
    const DIMS: [usize; 14] = [
        3, 13, 57, 128, 256, 384, 418, 511, 512, 768, 896, 1024, 1536, 3072,
    ];
    const TRIALS: usize = 10;
    let metrics = [
        Metric::Cosine,
        Metric::CosineNormalized,
        Metric::InnerProduct,
        Metric::L2,
    ];

    let mut rng = create_rnd_from_seed_in_tests(0x4fa598591f6);
    for dim in DIMS {
        for _ in 0..TRIALS {
            for metric in metrics {
                // Draw both vectors before compressing so the RNG stream is consumed in a
                // fixed order.
                let first = create_random_vector(dim, &mut rng, -1.0, 1.0);
                let second = create_random_vector(dim, &mut rng, -1.0, 1.0);
                let first = minmax_compress_vector::<N>(&first);
                let second = minmax_compress_vector::<N>(&second);
                test_distance::<N>((&first, dim), (&second, dim), metric);
            }
        }
    }
}
expand_to_bitrates!(test_mm_distance_fns_happy, test_mm_distance_fns_happy_bits);
/// Checks that comparing vectors whose dimensions differ by one panics for every metric.
fn test_mm_distance_fns_panic_bits<const N: usize>()
where
    Unsigned: Representation<N>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
            BitSlice<'a, N, Unsigned>,
            BitSlice<'b, N, Unsigned>,
            diskann_quantization::distances::MathematicalResult<u32>,
        >,
{
    let metrics = [
        Metric::L2,
        Metric::CosineNormalized,
        Metric::InnerProduct,
        Metric::Cosine,
    ];
    let mut rng = create_rnd_from_seed_in_tests(0x9ce578592f7);

    for dim in [4, 276, 380, 3180] {
        for metric in metrics {
            // The first vector has one more dimension than the second.
            let longer = create_random_vector(dim + 1, &mut rng, -1.0, 1.0);
            let shorter = create_random_vector(dim, &mut rng, -1.0, 1.0);
            let longer = minmax_compress_vector::<N>(&longer);
            let shorter = minmax_compress_vector::<N>(&shorter);

            let outcome = std::panic::catch_unwind(|| {
                test_distance::<N>((&longer, dim + 1), (&shorter, dim), metric);
            });
            assert!(
                outcome.is_err(),
                "Expected panic for dim {} and metric {:?}",
                dim,
                metric
            );
        }
    }
}
expand_to_bitrates!(test_mm_distance_fns_panic, test_mm_distance_fns_panic_bits);
/// Checks that `full_dimension` recovers the dimension of freshly compressed vectors.
fn test_full_dimension_happy_bits<const NBITS: usize>()
where
    MinMaxElement<NBITS>: VectorRepr<Error = MMConvertError>,
    Unsigned: Representation<NBITS>,
{
    let mut rng = create_rnd_from_seed_in_tests(0x9ce578592f7);
    for dim in [1, 2, 4, 8, 16, 32, 64, 128, 256] {
        let original = create_random_vector(dim, &mut rng, -1.0, 1.0);
        let compressed = minmax_compress_vector::<NBITS>(&original);
        let extracted = MinMaxElement::<NBITS>::full_dimension(&compressed).unwrap();
        assert_eq!(extracted, dim, "Dimension extraction failed for dim {}", dim);
    }
}
expand_to_bitrates!(test_full_dimension_happy, test_full_dimension_happy_bits);
fn test_full_dimension_error_cases_bits<const NBITS: usize>()
where
MinMaxElement<NBITS>: VectorRepr<Error = MMConvertError>,
Unsigned: Representation<NBITS>,
{
let short_slice = vec![MinMaxElement::<NBITS>::from_u8(1).unwrap(); 2]; let result = MinMaxElement::<NBITS>::full_dimension(&short_slice);
assert_eq!(
result.unwrap_err(),
MMConvertError::MetaParseError(MetaParseError::NotCanonical(2))
);
let empty_slice: Vec<MinMaxElement<NBITS>> = vec![];
let result = MinMaxElement::<NBITS>::full_dimension(&empty_slice);
assert_eq!(
result.unwrap_err(),
MMConvertError::MetaParseError(MetaParseError::NotCanonical(0))
);
}
expand_to_bitrates!(
test_full_dimension_error_cases,
test_full_dimension_error_cases_bits
);
/// Checks that `as_f32` and `as_f32_into` both produce the same values as calling
/// `decompress_into` on the parsed view.
fn test_as_f32_matches_decompress_into_bits<const NBITS: usize>()
where
    MinMaxElement<NBITS>: VectorRepr<Error = MMConvertError>,
    Unsigned: Representation<NBITS>,
{
    let mut rng = create_rnd_from_seed_in_tests(0x9ce578592f7);
    for dim in [4, 32, 64, 73, 128, 384, 411, 896, 1536, 3072] {
        let original = create_random_vector(dim, &mut rng, -1.0, 1.0);
        let compressed = minmax_compress_vector::<NBITS>(&original);

        // Path 1: allocate-and-return.
        let allocated = MinMaxElement::<NBITS>::as_f32(&compressed);
        assert!(allocated.is_ok());
        let allocated = allocated.unwrap();

        // Path 2: write into a caller-provided buffer.
        let mut in_place = vec![f32::default(); dim];
        let status = MinMaxElement::<NBITS>::as_f32_into(&compressed, &mut in_place);
        assert!(status.is_ok());

        // Reference: decompress the parsed view directly.
        let parsed = MinMaxElement::<NBITS>::from_raw(&compressed);
        assert!(parsed.is_ok());
        let parsed = parsed.unwrap();
        let mut reference = vec![0.0f32; dim];
        let status = parsed.decompress_into(&mut reference);
        assert!(status.is_ok());

        let allocated_values: &[f32] = &allocated;
        for candidate in [in_place.as_slice(), allocated_values] {
            assert_eq!(candidate.len(), reference.len());
            for (i, (&a, &b)) in candidate.iter().zip(reference.iter()).enumerate() {
                assert!(
                    (a - b).abs() < 1e-6,
                    "Mismatch at index {} for dim {}: as_f32={}, decompress_into={}",
                    i,
                    dim,
                    a,
                    b
                );
            }
        }
    }
}
expand_to_bitrates!(
    test_as_f32_matches_decompress_into,
    test_as_f32_matches_decompress_into_bits
);
/// Checks that `as_f32` returns an error for a too-short slice, an empty slice, and a
/// hand-built 34-byte buffer that is expected to be non-canonical.
fn test_as_f32_error_cases_bits<const NBITS: usize>()
where
    MinMaxElement<NBITS>: VectorRepr<Error = MMConvertError>,
    Unsigned: Representation<NBITS>,
{
    let too_short = vec![MinMaxElement::<NBITS>::from_u8(1).unwrap(); 5];
    assert!(
        MinMaxElement::<NBITS>::as_f32(&too_short).is_err(),
        "as_f32 should fail on too short slice"
    );

    let empty: Vec<MinMaxElement<NBITS>> = Vec::new();
    assert!(
        MinMaxElement::<NBITS>::as_f32(&empty).is_err(),
        "as_f32 should fail on empty slice"
    );

    // 34 bytes, all zero except the fourth, which is 10.
    let mut raw_bytes = vec![0_u8; 34];
    raw_bytes[3] = 10;
    let non_canonical = bytemuck::cast_slice::<u8, MinMaxElement<NBITS>>(&raw_bytes);
    assert!(
        MinMaxElement::<NBITS>::as_f32(non_canonical).is_err(),
        "as_f32 should fail on non-canonical data"
    );
}
expand_to_bitrates!(test_as_f32_error_cases, test_as_f32_error_cases_bits);
fn test_as_f32_into_error_bits<const NBITS: usize>()
where
MinMaxElement<NBITS>: VectorRepr<Error = MMConvertError>,
Unsigned: Representation<NBITS>,
{
let dim = 10;
let mut rng = create_rnd_from_seed_in_tests(0x9ce578592f7);
let vec_f32 = create_random_vector(dim, &mut rng, -1.0, 1.0);
let compressed = minmax_compress_vector::<NBITS>(&vec_f32);
let mut small_buffer = vec![0.0f32; dim - 2];
let result = MinMaxElement::<NBITS>::as_f32_into(&compressed, &mut small_buffer);
assert_eq!(
result.unwrap_err(),
MMConvertError::WrongLength(dim, dim - 2)
);
let mut large_buffer = vec![0.0f32; dim + 5];
let result = MinMaxElement::<NBITS>::as_f32_into(&compressed, &mut large_buffer);
assert_eq!(
result.unwrap_err(),
MMConvertError::WrongLength(dim, dim + 5)
);
let invalid_slice = vec![MinMaxElement::<NBITS>::from_u8(1).unwrap(); 8];
let mut buffer = vec![0.0f32; 3];
let result = MinMaxElement::<NBITS>::as_f32_into(&invalid_slice, &mut buffer);
assert!(
result.is_err(),
"as_f32_into should fail with invalid data format"
);
let empty_slice: Vec<MinMaxElement<NBITS>> = vec![];
let mut buffer = vec![0.0f32; 5];
let result = MinMaxElement::<NBITS>::as_f32_into(&empty_slice, &mut buffer);
assert!(result.is_err(), "as_f32_into should fail with empty slice");
}
expand_to_bitrates!(test_as_f32_into_error, test_as_f32_into_error_bits);
}