use std::fmt::Debug;
use diskann_vector::{
DistanceFunction, PreprocessedDistanceFunction,
conversion::CastFromSlice,
distance::{DistanceProvider, Metric},
};
use half::f16;
use thiserror::Error;
use crate::{ANNError, internal::convert_f32::ConvertF32};
/// Marker trait bundling the bounds expected of a vector's scalar element type.
///
/// It declares no items of its own; it only gathers the supertrait requirements
/// (`Sized`, `bytemuck::Pod`, `num_traits::FromPrimitive`, `Debug`, `Default`,
/// `Send` and `Sync`) under a single name so generic code can write one bound.
pub trait VectorElement:
Sized + bytemuck::Pod + num_traits::FromPrimitive + std::fmt::Debug + Default + Send + Sync
{
}
/// Blanket implementation: every type that satisfies the full set of bounds is
/// automatically a [`VectorElement`]; no manual implementation is required.
impl<T> VectorElement for T where
T: Sized + bytemuck::Pod + num_traits::FromPrimitive + std::fmt::Debug + Default + Send + Sync
{
}
/// A [`VectorElement`] together with the distance functions used to compare
/// slices of it and the conversions used to view those slices as `f32`.
pub trait VectorRepr: VectorElement {
/// Error returned by the fallible methods of this trait. It must be
/// convertible into [`ANNError`].
type Error: std::error::Error + Debug + Send + Sync + Into<ANNError>;
/// Distance function over two borrowed slices of `Self`, producing an `f32`.
type Distance: for<'a, 'b> DistanceFunction<&'a [Self], &'b [Self], f32>
+ Debug
+ Send
+ Sync
+ 'static;
/// Distance function over a single borrowed slice of `Self`, producing an
/// `f32`. Instances are created from a query by [`VectorRepr::query_distance`].
type QueryDistance: for<'a> PreprocessedDistanceFunction<&'a [Self], f32>
+ Debug
+ Send
+ Sync
+ 'static;
/// Reports the full dimension of `vec`.
///
/// NOTE(review): the implementations generated in this file simply return
/// `vec.len()`; presumably other representations may report a different
/// value or fail — confirm against those implementations.
fn full_dimension(vec: &[Self]) -> Result<usize, Self::Error>;
/// Builds the two-argument distance function for `metric`.
///
/// `dim` is optional. NOTE(review): it looks like a dimension hint for the
/// distance implementation — confirm its exact meaning with `diskann_vector`.
fn distance(metric: Metric, dim: Option<usize>) -> Self::Distance;
/// Builds a distance function for `metric` that is bound to `query`, so later
/// evaluations only need to supply the other vector.
fn query_distance(query: &[Self], metric: Metric) -> Self::QueryDistance;
/// Returns a value that dereferences to the `f32` form of `data`.
fn as_f32(data: &[Self]) -> Result<impl std::ops::Deref<Target = [f32]>, Self::Error>;
/// Writes the `f32` form of `src` into the caller-provided buffer `dst`.
///
/// NOTE(review): the implementations generated in this file return an error
/// when `src` and `dst` differ in length; other implementations should be
/// checked for their own requirements.
fn as_f32_into(src: &[Self], dst: &mut [f32]) -> Result<(), Self::Error>;
}
/// Error describing a length mismatch between a source vector and the `f32`
/// destination slice it was supposed to be written into.
///
/// The `default_impl!`-generated `as_f32_into` implementations in this file
/// return it when `dst.len() != src.len()`.
#[derive(Debug, Clone, PartialEq, Error)]
#[error("Unable to set full-precision vector of length {src} into slice of length {dst}")]
pub struct NativeTypeLengthError {
/// Length of the source vector.
src: usize,
/// Length of the destination slice.
dst: usize,
}
impl From<NativeTypeLengthError> for ANNError {
fn from(err: NativeTypeLengthError) -> ANNError {
ANNError::log_index_error(format!(
"Unable to set full-precision vector of length {} into slice of length {}",
err.src, err.dst
))
}
}
// Generates a `VectorRepr` implementation for a native element type.
//
// Two invocation forms are accepted:
//
// * `default_impl!(T, QueryDistance = .., query_impl = .., into_impl = ..)`
//   - `QueryDistance`: the type used as `VectorRepr::QueryDistance`.
//   - `query_impl`: callable `(query: &[T], metric) -> QueryDistance`.
//   - `into_impl`: callable `(src: &[T], dst: &mut [f32])` that fills `dst`; it is
//     only invoked after the two slices were verified to have equal length.
// * `default_impl!(T)`: shorthand that uses `BufferedDistance<T>` for queries and
//   converts each element to `f32` through `Into`.
macro_rules! default_impl {
    (
        $elem:ty,
        QueryDistance = $query_ty:ty,
        query_impl = $make_query:expr,
        into_impl = $copy_into:expr
    ) => {
        impl VectorRepr for $elem {
            type Error = NativeTypeLengthError;
            type Distance = diskann_vector::distance::Distance<$elem, $elem>;
            type QueryDistance = $query_ty;

            // Delegates to the element type's own distance provider.
            fn distance(metric: Metric, dim: Option<usize>) -> Self::Distance {
                <$elem>::distance_comparer(metric, dim)
            }

            // Forwards to the caller-supplied query constructor.
            fn query_distance(query: &[Self], metric: Metric) -> Self::QueryDistance {
                ($make_query)(query, metric)
            }

            // For these element types the dimension is the slice length, so
            // this never fails.
            fn full_dimension(slice: &[Self]) -> Result<usize, Self::Error> {
                Ok(slice.len())
            }

            // Conversion is delegated to the `ConvertF32` helper; never fails.
            fn as_f32(data: &[Self]) -> Result<impl std::ops::Deref<Target = [f32]>, Self::Error> {
                Ok(data.convert_f32())
            }

            // Refuses mismatched lengths before handing off to `into_impl`, so the
            // supplied callable may assume `src.len() == dst.len()`.
            fn as_f32_into(src: &[Self], dst: &mut [f32]) -> Result<(), Self::Error> {
                let (src_len, dst_len) = (src.len(), dst.len());
                if src_len == dst_len {
                    ($copy_into)(src, dst);
                    Ok(())
                } else {
                    Err(NativeTypeLengthError {
                        src: src_len,
                        dst: dst_len,
                    })
                }
            }
        }
    };
    ($elem:ty) => {
        default_impl!(
            $elem,
            QueryDistance = BufferedDistance<$elem>,
            query_impl = |query: &[$elem], metric| BufferedDistance::new(query.into(), metric),
            into_impl = |src: &[$elem], dst: &mut [f32]| {
                // Element-wise widening through `Into<f32>`.
                dst.iter_mut()
                    .zip(src)
                    .for_each(|(slot, &value)| *slot = value.into());
            }
        );
    };
}
// `i8` and `u8` use the single-argument form of `default_impl!`: queries are held
// in a `BufferedDistance` of the same element type and `as_f32_into` converts each
// element to `f32` via `Into`.
default_impl!(i8);
default_impl!(u8);
// `f32` is already in the target representation: the query is copied into an
// owned buffer unchanged and `as_f32_into` is a plain slice copy.
default_impl!(
    f32,
    QueryDistance = BufferedDistance<f32>,
    query_impl = |query: &[f32], metric| {
        let owned = Box::<[f32]>::from(query);
        BufferedDistance::new(owned, metric)
    },
    // `copy_from_slice` requires equal lengths; the generated `as_f32_into`
    // checks that before calling this closure.
    into_impl = |src: &[f32], dst: &mut [f32]| dst.copy_from_slice(src)
);
// `f16` data is compared against a query that is converted to `f32` once, when
// the query distance is built, and stored in that form.
default_impl!(
    f16,
    QueryDistance = BufferedDistance<f16, f32>,
    query_impl = |query: &[f16], metric| {
        // Zero-filled scratch buffer of the query's length, overwritten by the cast.
        let mut converted: Box<[f32]> = vec![0.0_f32; query.len()].into_boxed_slice();
        converted.cast_from_slice(query);
        BufferedDistance::new(converted, metric)
    },
    into_impl = |src: &[f16], dst: &mut [f32]| {
        dst.cast_from_slice(src);
    }
);
/// A distance function bundled with an owned copy of the query it is evaluated
/// against.
///
/// `T` is the element type of the vectors passed to `evaluate_similarity`; `U` is
/// the element type the query is stored in and defaults to `T`.
#[derive(Debug)]
pub struct BufferedDistance<T, U = T>
where
U: 'static,
T: 'static,
{
/// Owned query buffer; supplied as the first argument of every distance call.
query: Box<[U]>,
/// Distance function invoked with the stored query and the candidate vector.
f: diskann_vector::distance::Distance<U, T>,
}
impl<T, U> BufferedDistance<T, U> {
    /// Creates a buffered distance for `metric` that takes ownership of `query`.
    ///
    /// The distance function is obtained from `U`'s [`DistanceProvider`]
    /// implementation, with the query length passed as the dimension.
    pub fn new(query: Box<[U]>, metric: Metric) -> Self
    where
        U: DistanceProvider<T>,
    {
        // Build the comparer first so `query` can be moved into the struct afterwards.
        let f = U::distance_comparer(metric, Some(query.len()));
        Self { query, f }
    }
}
impl<T, U> PreprocessedDistanceFunction<&[T]> for BufferedDistance<T, U> {
    /// Evaluates the stored distance function with the buffered query as the
    /// first argument and `x` as the second.
    #[inline(always)]
    fn evaluate_similarity(&self, x: &[T]) -> f32 {
        let Self { query, f } = self;
        f.call(query, x)
    }
}
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;

    use diskann_vector::Half;

    use super::*;

    /// Compile-time probe: this only type-checks when `T` satisfies
    /// [`VectorElement`], so the returned `true` carries no runtime information.
    fn implements_vector_element<T>() -> bool
    where
        T: VectorElement,
    {
        let _marker: PhantomData<T> = PhantomData;
        true
    }

    /// Each element type listed here must pick up the blanket `VectorElement` impl.
    #[test]
    fn test_vector_element() {
        assert!(implements_vector_element::<Half>());
        assert!(implements_vector_element::<f32>());
        assert!(implements_vector_element::<i8>());
        assert!(implements_vector_element::<i16>());
        assert!(implements_vector_element::<u8>());
        assert!(implements_vector_element::<u16>());
    }
}