use diskann_vector::PureDistanceFunction;
use diskann_wide::{ARCH, Architecture, arch::Target2};
#[cfg(target_arch = "x86_64")]
use diskann_wide::{
SIMDCast, SIMDDotProduct, SIMDMulAdd, SIMDReinterpret, SIMDSumTree, SIMDVector,
};
use super::{Binary, BitSlice, BitTranspose, Dense, Representation, Unsigned};
use crate::distances::{Hamming, InnerProduct, MV, MathematicalResult, SquaredL2, check_lengths};
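/// A bit-packed slice of unsigned `N`-bit values with element permutation
/// `Perm` (`Dense` by default).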
type USlice<'a, const N: usize, Perm = Dense> = BitSlice<'a, N, Unsigned, Perm>;
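// Forwards `Target2` for `$op` from `$arch` to the next-lower architecture
// level via `arch.retarget()`, for each listed bit width `$N` or pair
// `($N, $M)`.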
macro_rules! retarget {
($arch:path, $op:ty, ($N:literal, $M:literal)) => {
impl Target2<
$arch,
MathematicalResult<u32>,
USlice<'_, $N>,
USlice<'_, $M>,
> for $op {
#[inline(always)]
fn run(
self,
arch: $arch,
x: USlice<'_, $N>,
y: USlice<'_, $M>
) -> MathematicalResult<u32> {
self.run(arch.retarget(), x, y)
}
}
};
($arch:path, $op:ty, $N:literal) => {
retarget!($arch, $op, ($N, $N));
};
($arch:path, $op:ty, $($args:tt),+ $(,)?) => {
$(retarget!($arch, $op, $args);)+
};
}
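// Implements `PureDistanceFunction` for `$op` by dispatching through the
// architecture detected at startup (`diskann_wide::ARCH`).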
macro_rules! dispatch_pure {
($op:ty, ($N:literal, $M:literal)) => {
impl PureDistanceFunction<USlice<'_, $N>, USlice<'_, $M>, MathematicalResult<u32>> for $op {
#[inline(always)]
fn evaluate(x: USlice<'_, $N>, y: USlice<'_, $M>) -> MathematicalResult<u32> {
(diskann_wide::ARCH).run2(Self, x, y)
}
}
};
($op:ty, $N:literal) => {
dispatch_pure!($op, ($N, $N));
};
($op:ty, $($args:tt),+ $(,)?) => {
$(dispatch_pure!($op, $args);)+
}
}
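// Tail loaders: read exactly one to four bytes of packed data into a `u32`
// without touching memory past the end of the slice, then pass the widened
// word to the continuation `f`.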
#[cfg(target_arch = "x86_64")]
unsafe fn load_one<F, R>(ptr: *const u32, mut f: F) -> R
where
F: FnMut(u32) -> R,
{
f(unsafe { ptr.cast::<u8>().read_unaligned() }.into())
}
#[cfg(target_arch = "x86_64")]
unsafe fn load_two<F, R>(ptr: *const u32, mut f: F) -> R
where
F: FnMut(u32) -> R,
{
f(unsafe { ptr.cast::<u16>().read_unaligned() }.into())
}
#[cfg(target_arch = "x86_64")]
unsafe fn load_three<F, R>(ptr: *const u32, mut f: F) -> R
where
F: FnMut(u32) -> R,
{
let lo: u32 = unsafe { ptr.cast::<u16>().read_unaligned() }.into();
let hi: u32 = unsafe { ptr.cast::<u8>().add(2).read_unaligned() }.into();
f(lo | hi << 16)
}
#[cfg(target_arch = "x86_64")]
unsafe fn load_four<F, R>(ptr: *const u32, mut f: F) -> R
where
F: FnMut(u32) -> R,
{
f(unsafe { ptr.read_unaligned() })
}
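/// Word-at-a-time kernel shared by the 1-bit distances: combine two packed
/// words (or bytes) and reduce the result with a population count.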
trait BitVectorOp<Repr>
where
Repr: Representation<1>,
{
fn on_u64(x: u64, y: u64) -> u32;
fn on_u8(x: u8, y: u8) -> u32;
}
impl BitVectorOp<Unsigned> for SquaredL2 {
#[inline(always)]
fn on_u64(x: u64, y: u64) -> u32 {
(x ^ y).count_ones()
}
#[inline(always)]
fn on_u8(x: u8, y: u8) -> u32 {
(x ^ y).count_ones()
}
}
impl BitVectorOp<Binary> for Hamming {
#[inline(always)]
fn on_u64(x: u64, y: u64) -> u32 {
(x ^ y).count_ones()
}
#[inline(always)]
fn on_u8(x: u8, y: u8) -> u32 {
(x ^ y).count_ones()
}
}
impl BitVectorOp<Unsigned> for InnerProduct {
#[inline(always)]
fn on_u64(x: u64, y: u64) -> u32 {
(x & y).count_ones()
}
#[inline(always)]
fn on_u8(x: u8, y: u8) -> u32 {
(x & y).count_ones()
}
}
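// Generic 1-bit distance loop: consume 64-bit words first, then whole bytes,
// and finally mask away the padding bits of a trailing partial byte.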
#[inline(always)]
fn bitvector_op<Op, Repr>(
x: BitSlice<'_, 1, Repr>,
y: BitSlice<'_, 1, Repr>,
) -> MathematicalResult<u32>
where
Repr: Representation<1>,
Op: BitVectorOp<Repr>,
{
let len = check_lengths!(x, y)?;
let px: *const u64 = x.as_ptr().cast();
let py: *const u64 = y.as_ptr().cast();
let mut i = 0;
let mut s: u32 = 0;
let blocks = len / 64;
while i < blocks {
let vx = unsafe { px.add(i).read_unaligned() };
let vy = unsafe { py.add(i).read_unaligned() };
s += Op::on_u64(vx, vy);
i += 1;
}
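// Switch from u64-block indexing to byte indexing for the tail.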
i *= 8;
let px: *const u8 = x.as_ptr();
let py: *const u8 = y.as_ptr();
let blocks = len / 8;
while i < blocks {
let vx = unsafe { px.add(i).read_unaligned() };
let vy = unsafe { py.add(i).read_unaligned() };
s += Op::on_u8(vx, vy);
i += 1;
}
if i * 8 != len {
let vx = unsafe { px.add(i).read_unaligned() };
let vy = unsafe { py.add(i).read_unaligned() };
let m = (0x01u8 << (len - 8 * i)) - 1;
s += Op::on_u8(vx & m, vy & m)
}
Ok(MV::new(s))
}
impl PureDistanceFunction<BitSlice<'_, 1, Binary>, BitSlice<'_, 1, Binary>, MathematicalResult<u32>>
for Hamming
{
fn evaluate(x: BitSlice<'_, 1, Binary>, y: BitSlice<'_, 1, Binary>) -> MathematicalResult<u32> {
bitvector_op::<Hamming, Binary>(x, y)
}
}
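// 8-bit elements occupy one byte each, so squared-L2 can delegate to the
// float kernel for `&[u8]` and truncate the result back to `u32`.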
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for SquaredL2
where
A: Architecture,
diskann_vector::distance::SquaredL2: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
#[inline(always)]
fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
check_lengths!(x, y)?;
let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
diskann_vector::distance::SquaredL2 {},
arch,
x.as_slice(),
y.as_slice(),
);
Ok(MV::new(r.into_inner() as u32))
}
}
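// x86-64-v3 kernel for 4-bit squared-L2: each u32 packs eight nibbles.
// Shifting by 0/4/8/12 and masking with 0x000f000f spreads the nibbles into
// i16 lanes, where `dot_simd(d, d)` accumulates squared differences into
// four independent i32 accumulators.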
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
for SquaredL2
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: USlice<'_, 4>,
y: USlice<'_, 4>,
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;
diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);
let px_u32: *const u32 = x.as_ptr().cast();
let py_u32: *const u32 = y.as_ptr().cast();
let mut i = 0;
let mut s: u32 = 0;
let blocks = len / 8;
if i < blocks {
let mut s0 = i32s::default(arch);
let mut s1 = i32s::default(arch);
let mut s2 = i32s::default(arch);
let mut s3 = i32s::default(arch);
let mask = u32s::splat(arch, 0x000f000f);
while i + 8 < blocks {
let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
let d = wx - wy;
s0 = s0.dot_simd(d, d);
let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
let d = wx - wy;
s1 = s1.dot_simd(d, d);
let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
let d = wx - wy;
s2 = s2.dot_simd(d, d);
let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
let d = wx - wy;
s3 = s3.dot_simd(d, d);
i += 8;
}
let remainder = blocks - i;
let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };
let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
let d = wx - wy;
s0 = s0.dot_simd(d, d);
let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
let d = wx - wy;
s1 = s1.dot_simd(d, d);
let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
let d = wx - wy;
s2 = s2.dot_simd(d, d);
let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
let d = wx - wy;
s3 = s3.dot_simd(d, d);
i += remainder;
s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
}
i *= 8;
if i != len {
#[inline(never)]
fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
let mut s: i32 = 0;
for i in from..x.len() {
let ix = unsafe { x.get_unchecked(i) } as i32;
let iy = unsafe { y.get_unchecked(i) } as i32;
let d = ix - iy;
s += d * d;
}
s as u32
}
s += fallback(x, y, i);
}
Ok(MV::new(s))
}
}
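// x86-64-v3 kernel for 2-bit squared-L2: the same shift-and-mask unpacking
// as the 4-bit kernel, but with eight shift positions (0..=14 in steps of 2)
// and the mask 0x00030003, visiting each of the four accumulators twice per
// word.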
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
for SquaredL2
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: USlice<'_, 2>,
y: USlice<'_, 2>,
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;
diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);
let px_u32: *const u32 = x.as_ptr().cast();
let py_u32: *const u32 = y.as_ptr().cast();
let mut i = 0;
let mut s: u32 = 0;
let blocks = len / 16;
if i < blocks {
let mut s0 = i32s::default(arch);
let mut s1 = i32s::default(arch);
let mut s2 = i32s::default(arch);
let mut s3 = i32s::default(arch);
let mask = u32s::splat(arch, 0x00030003);
while i + 8 < blocks {
let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
let d = wx - wy;
s0 = s0.dot_simd(d, d);
let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
let d = wx - wy;
s1 = s1.dot_simd(d, d);
let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
let d = wx - wy;
s2 = s2.dot_simd(d, d);
let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
let d = wx - wy;
s3 = s3.dot_simd(d, d);
let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
let d = wx - wy;
s0 = s0.dot_simd(d, d);
let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
let d = wx - wy;
s1 = s1.dot_simd(d, d);
let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
let d = wx - wy;
s2 = s2.dot_simd(d, d);
let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
let d = wx - wy;
s3 = s3.dot_simd(d, d);
i += 8;
}
let remainder = blocks - i;
let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };
let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
let d = wx - wy;
s0 = s0.dot_simd(d, d);
let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
let d = wx - wy;
s1 = s1.dot_simd(d, d);
let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
let d = wx - wy;
s2 = s2.dot_simd(d, d);
let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
let d = wx - wy;
s3 = s3.dot_simd(d, d);
let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
let d = wx - wy;
s0 = s0.dot_simd(d, d);
let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
let d = wx - wy;
s1 = s1.dot_simd(d, d);
let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
let d = wx - wy;
s2 = s2.dot_simd(d, d);
let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
let d = wx - wy;
s3 = s3.dot_simd(d, d);
i += remainder;
s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
}
i *= 16;
if i != len {
#[inline(never)]
fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
let mut s: i32 = 0;
for i in from..x.len() {
let ix = unsafe { x.get_unchecked(i) } as i32;
let iy = unsafe { y.get_unchecked(i) } as i32;
let d = ix - iy;
s += d * d;
}
s as u32
}
s += fallback(x, y, i);
}
Ok(MV::new(s))
}
}
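// For 1-bit elements, squared-L2 reduces to a popcount of XOR, so every
// architecture shares the scalar word-at-a-time kernel.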
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for SquaredL2
where
A: Architecture,
{
fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
bitvector_op::<Self, Unsigned>(x, y)
}
}
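// Scalar reference implementation of squared-L2 for the remaining bit
// widths; this is the base case that architecture retargeting eventually
// reaches.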
macro_rules! impl_fallback_l2 {
($N:literal) => {
impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $N>> for SquaredL2 {
#[inline(never)]
fn run(
self,
_: diskann_wide::arch::Scalar,
x: USlice<'_, $N>,
y: USlice<'_, $N>
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;
let mut accum: i32 = 0;
for i in 0..len {
let ix: i32 = unsafe { x.get_unchecked(i) } as i32;
let iy: i32 = unsafe { y.get_unchecked(i) } as i32;
let diff = ix - iy;
accum += diff * diff;
}
Ok(MV::new(accum as u32))
}
}
};
($($N:literal),+ $(,)?) => {
$(impl_fallback_l2!($N);)+
};
}
impl_fallback_l2!(7, 6, 5, 4, 3, 2);
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, SquaredL2, 7, 6, 5, 3);
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V4, SquaredL2, 7, 6, 5, 4, 3, 2);
dispatch_pure!(SquaredL2, 1, 2, 3, 4, 5, 6, 7, 8);
#[cfg(target_arch = "aarch64")]
retarget!(
diskann_wide::arch::aarch64::Neon,
SquaredL2,
7,
6,
5,
4,
3,
2
);
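// 8-bit inner product delegates to the float kernel for `&[u8]`, mirroring
// the squared-L2 case above.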
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 8>> for InnerProduct
where
A: Architecture,
diskann_vector::distance::InnerProduct: for<'a> Target2<A, MV<f32>, &'a [u8], &'a [u8]>,
{
#[inline(always)]
fn run(self, arch: A, x: USlice<'_, 8>, y: USlice<'_, 8>) -> MathematicalResult<u32> {
check_lengths!(x, y)?;
let r: MV<f32> = <_ as Target2<_, _, _, _>>::run(
diskann_vector::distance::InnerProduct {},
arch,
x.as_slice(),
y.as_slice(),
);
Ok(MV::new(r.into_inner() as u32))
}
}
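// x86-64-v4 kernel for 2-bit inner product: crumbs are exposed with shifts
// of 0/2/4/6 and the byte mask 0x03030303, then multiplied with `dot_simd`
// over u8 x i8 lanes; both operands are at most 3, so the products cannot
// overflow.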
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V4, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
for InnerProduct
{
#[expect(non_camel_case_types)]
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V4,
x: USlice<'_, 2>,
y: USlice<'_, 2>,
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;
type i32s = <diskann_wide::arch::x86_64::V4 as Architecture>::i32x16;
type u32s = <diskann_wide::arch::x86_64::V4 as Architecture>::u32x16;
type u8s = <diskann_wide::arch::x86_64::V4 as Architecture>::u8x64;
type i8s = <diskann_wide::arch::x86_64::V4 as Architecture>::i8x64;
let px_u32: *const u32 = x.as_ptr().cast();
let py_u32: *const u32 = y.as_ptr().cast();
let mut i = 0;
let mut s: u32 = 0;
let blocks = len.div_ceil(16);
if i < blocks {
let mut s0 = i32s::default(arch);
let mut s1 = i32s::default(arch);
let mut s2 = i32s::default(arch);
let mut s3 = i32s::default(arch);
let mask = u32s::splat(arch, 0x03030303);
while i + 16 < blocks {
let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
let wx: u8s = (vx & mask).reinterpret_simd();
let wy: i8s = (vy & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);
let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);
let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
s2 = s2.dot_simd(wx, wy);
let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
s3 = s3.dot_simd(wx, wy);
i += 16;
}
let remainder = len / 4 - 4 * i;
let vx = unsafe { u8s::load_simd_first(arch, px_u32.add(i).cast::<u8>(), remainder) };
let vx: u32s = vx.reinterpret_simd();
let vy = unsafe { u8s::load_simd_first(arch, py_u32.add(i).cast::<u8>(), remainder) };
let vy: u32s = vy.reinterpret_simd();
let wx: u8s = (vx & mask).reinterpret_simd();
let wy: i8s = (vy & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);
let wx: u8s = ((vx >> 2) & mask).reinterpret_simd();
let wy: i8s = ((vy >> 2) & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);
let wx: u8s = ((vx >> 4) & mask).reinterpret_simd();
let wy: i8s = ((vy >> 4) & mask).reinterpret_simd();
s2 = s2.dot_simd(wx, wy);
let wx: u8s = ((vx >> 6) & mask).reinterpret_simd();
let wy: i8s = ((vy >> 6) & mask).reinterpret_simd();
s3 = s3.dot_simd(wx, wy);
s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
i = (4 * i) + remainder;
}
i *= 4;
debug_assert!(len - i <= 3);
let rest = (len - i).min(3);
if i != len {
for j in 0..rest {
let ix = unsafe { x.get_unchecked(i + j) } as u32;
let iy = unsafe { y.get_unchecked(i + j) } as u32;
s += ix * iy;
}
}
Ok(MV::new(s))
}
}
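// x86-64-v3 kernel for 4-bit inner product: the same nibble unpacking as the
// 4-bit squared-L2 kernel, accumulating `dot_simd(wx, wy)` instead of
// squared differences.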
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 4>, USlice<'_, 4>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: USlice<'_, 4>,
y: USlice<'_, 4>,
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;
diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);
let px_u32: *const u32 = x.as_ptr().cast();
let py_u32: *const u32 = y.as_ptr().cast();
let mut i = 0;
let mut s: u32 = 0;
let blocks = len / 8;
if i < blocks {
let mut s0 = i32s::default(arch);
let mut s1 = i32s::default(arch);
let mut s2 = i32s::default(arch);
let mut s3 = i32s::default(arch);
let mask = u32s::splat(arch, 0x000f000f);
while i + 8 < blocks {
let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);
let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);
let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
s2 = s2.dot_simd(wx, wy);
let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
s3 = s3.dot_simd(wx, wy);
i += 8;
}
let remainder = blocks - i;
let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };
let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);
let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);
let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
s2 = s2.dot_simd(wx, wy);
let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
s3 = s3.dot_simd(wx, wy);
i += remainder;
s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
}
i *= 8;
if i != len {
#[inline(never)]
fn fallback(x: USlice<'_, 4>, y: USlice<'_, 4>, from: usize) -> u32 {
let mut s: u32 = 0;
for i in from..x.len() {
let ix = unsafe { x.get_unchecked(i) } as u32;
let iy = unsafe { y.get_unchecked(i) } as u32;
s += ix * iy;
}
s
}
s += fallback(x, y, i);
}
Ok(MV::new(s))
}
}
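// x86-64-v3 kernel for 2-bit inner product, following the 2-bit squared-L2
// unpacking scheme.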
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 2>, USlice<'_, 2>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: USlice<'_, 2>,
y: USlice<'_, 2>,
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;
diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);
let px_u32: *const u32 = x.as_ptr().cast();
let py_u32: *const u32 = y.as_ptr().cast();
let mut i = 0;
let mut s: u32 = 0;
let blocks = len / 16;
if i < blocks {
let mut s0 = i32s::default(arch);
let mut s1 = i32s::default(arch);
let mut s2 = i32s::default(arch);
let mut s3 = i32s::default(arch);
let mask = u32s::splat(arch, 0x00030003);
while i + 8 < blocks {
let vx = unsafe { u32s::load_simd(arch, px_u32.add(i)) };
let vy = unsafe { u32s::load_simd(arch, py_u32.add(i)) };
let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);
let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);
let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
s2 = s2.dot_simd(wx, wy);
let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
s3 = s3.dot_simd(wx, wy);
let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);
let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);
let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
s2 = s2.dot_simd(wx, wy);
let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
s3 = s3.dot_simd(wx, wy);
i += 8;
}
let remainder = blocks - i;
let vx = unsafe { u32s::load_simd_first(arch, px_u32.add(i), remainder) };
let vy = unsafe { u32s::load_simd_first(arch, py_u32.add(i), remainder) };
let wx: i16s = (vx & mask).reinterpret_simd();
let wy: i16s = (vy & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);
let wx: i16s = (vx >> 2 & mask).reinterpret_simd();
let wy: i16s = (vy >> 2 & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);
let wx: i16s = (vx >> 4 & mask).reinterpret_simd();
let wy: i16s = (vy >> 4 & mask).reinterpret_simd();
s2 = s2.dot_simd(wx, wy);
let wx: i16s = (vx >> 6 & mask).reinterpret_simd();
let wy: i16s = (vy >> 6 & mask).reinterpret_simd();
s3 = s3.dot_simd(wx, wy);
let wx: i16s = (vx >> 8 & mask).reinterpret_simd();
let wy: i16s = (vy >> 8 & mask).reinterpret_simd();
s0 = s0.dot_simd(wx, wy);
let wx: i16s = (vx >> 10 & mask).reinterpret_simd();
let wy: i16s = (vy >> 10 & mask).reinterpret_simd();
s1 = s1.dot_simd(wx, wy);
let wx: i16s = (vx >> 12 & mask).reinterpret_simd();
let wy: i16s = (vy >> 12 & mask).reinterpret_simd();
s2 = s2.dot_simd(wx, wy);
let wx: i16s = (vx >> 14 & mask).reinterpret_simd();
let wy: i16s = (vy >> 14 & mask).reinterpret_simd();
s3 = s3.dot_simd(wx, wy);
i += remainder;
s = ((s0 + s1) + (s2 + s3)).sum_tree() as u32;
}
i *= 16;
if i != len {
#[inline(never)]
fn fallback(x: USlice<'_, 2>, y: USlice<'_, 2>, from: usize) -> u32 {
let mut s: u32 = 0;
for i in from..x.len() {
let ix = unsafe { x.get_unchecked(i) } as u32;
let iy = unsafe { y.get_unchecked(i) } as u32;
s += ix * iy;
}
s
}
s += fallback(x, y, i);
}
Ok(MV::new(s))
}
}
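// For 1-bit elements the inner product is a popcount of AND, handled by the
// scalar word-at-a-time kernel on every architecture.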
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 1>, USlice<'_, 1>> for InnerProduct
where
A: Architecture,
{
#[inline(always)]
fn run(self, _: A, x: USlice<'_, 1>, y: USlice<'_, 1>) -> MathematicalResult<u32> {
bitvector_op::<Self, Unsigned>(x, y)
}
}
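// Scalar reference implementation of the inner product, covering both the
// homogeneous widths and the mixed (8, M) pairs.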
macro_rules! impl_fallback_ip {
(($N:literal, $M:literal)) => {
impl Target2<diskann_wide::arch::Scalar, MathematicalResult<u32>, USlice<'_, $N>, USlice<'_, $M>> for InnerProduct {
#[inline(never)]
fn run(
self,
_: diskann_wide::arch::Scalar,
x: USlice<'_, $N>,
y: USlice<'_, $M>
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;
let mut accum: u32 = 0;
for i in 0..len {
let ix = unsafe { x.get_unchecked(i) } as u32;
let iy = unsafe { y.get_unchecked(i) } as u32;
accum += ix * iy;
}
Ok(MV::new(accum))
}
}
};
($N:literal) => {
impl_fallback_ip!(($N, $N));
};
($($args:tt),+ $(,)?) => {
$(impl_fallback_ip!($args);)+
};
}
impl_fallback_ip!(7, 6, 5, 4, 3, 2, (8, 4), (8, 2), (8, 1));
#[cfg(target_arch = "x86_64")]
retarget!(diskann_wide::arch::x86_64::V3, InnerProduct, 7, 6, 5, 3);
#[cfg(target_arch = "x86_64")]
retarget!(
diskann_wide::arch::x86_64::V4,
InnerProduct,
7,
6,
5,
4,
3,
(8, 4),
(8, 2),
(8, 1)
);
dispatch_pure!(
InnerProduct,
1,
2,
3,
4,
5,
6,
7,
(8, 8),
(8, 4),
(8, 2),
(8, 1)
);
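// Mixed 8-bit x 4-bit inner product on x86-64-v3: `unpack_half` interleaves
// a 16-byte vector with its nibble-shifted copy and masks to the low nibble,
// placing one 4-bit y value alongside each 8-bit x value in element order;
// `_mm256_maddubs_epi16` then multiplies u8 x i8 pairs and adds adjacent
// products into i16 lanes. The nibble operand sits in the signed position,
// keeping each pairwise sum within i16 range.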
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 4>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: USlice<'_, 8>,
y: USlice<'_, 4>,
) -> MathematicalResult<u32> {
use std::arch::x86_64::_mm256_maddubs_epi16;
let len = check_lengths!(x, y)?;
diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
diskann_wide::alias!(u8s_16 = <diskann_wide::arch::x86_64::V3>::u8x16);
diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);
diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);
let px: *const u8 = x.as_ptr();
let py: *const u8 = y.as_ptr();
let mut i: usize = 0;
let mut s: u32 = 0;
#[inline(always)]
fn unpack_half(input: u8s_16) -> u8s_32 {
let combined = diskann_wide::LoHi::new(input, input >> 4).zip::<u8s_32>();
combined & u8s_32::splat(input.arch(), (1u8 << 4) - 1)
}
let blocks = len / 32;
if blocks > 0 {
let mut acc = i32s::default(arch);
let products = |x: u8s_32, y: u8s_32| -> i16s {
i16s::from_underlying(arch, unsafe {
_mm256_maddubs_epi16(x.to_underlying(), y.to_underlying())
})
};
let ones = i16s::splat(arch, 1);
while i + 4 <= blocks {
let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * i)) };
let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * i)) };
let m0 = products(vx, unpack_half(vy));
let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 1))) };
let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 1))) };
let m1 = products(vx, unpack_half(vy));
let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 2))) };
let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 2))) };
let m2 = products(vx, unpack_half(vy));
let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * (i + 3))) };
let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * (i + 3))) };
let m3 = products(vx, unpack_half(vy));
acc = acc.dot_simd(m0 + m1 + m2 + m3, ones);
i += 4;
}
while i < blocks {
let vx = unsafe { u8s_32::load_simd(arch, px.add(32 * i)) };
let vy = unsafe { u8s_16::load_simd(arch, py.add(16 * i)) };
acc = acc.dot_simd(products(vx, unpack_half(vy)), ones);
i += 1;
}
s = acc.sum_tree() as u32;
}
i *= 32;
if i != len {
#[inline(never)]
fn fallback(x: USlice<'_, 8>, y: USlice<'_, 4>, from: usize) -> u32 {
let mut s: u32 = 0;
for i in from..x.len() {
let ix = unsafe { x.get_unchecked(i) } as u32;
let iy = unsafe { y.get_unchecked(i) } as u32;
s += ix * iy;
}
s
}
s += fallback(x, y, i);
}
Ok(MV::new(s))
}
}
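// Mixed 8-bit x 2-bit inner product on x86-64-v3: `unpack_crumbs` expands
// crumbs to nibbles and then nibbles to bytes, yielding two 32-byte vectors
// of y values in element order to pair with the 64 x bytes of each block.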
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 2>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: USlice<'_, 8>,
y: USlice<'_, 2>,
) -> MathematicalResult<u32> {
use diskann_wide::SplitJoin;
use std::arch::x86_64::_mm256_maddubs_epi16;
let len = check_lengths!(x, y)?;
diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
diskann_wide::alias!(u8s_16 = <diskann_wide::arch::x86_64::V3>::u8x16);
diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);
diskann_wide::alias!(i16s = <diskann_wide::arch::x86_64::V3>::i16x16);
let px: *const u8 = x.as_ptr();
let py: *const u8 = y.as_ptr();
let mut i: usize = 0;
let mut s: u32 = 0;
let blocks = len / 64;
if blocks > 0 {
let mut acc = i32s::default(arch);
let products = |x: u8s_32, y: u8s_32| -> i16s {
i16s::from_underlying(arch, unsafe {
_mm256_maddubs_epi16(x.to_underlying(), y.to_underlying())
})
};
#[inline(always)]
fn unpack_sub<const N: u8>(input: u8s_16) -> u8s_32 {
let combined = diskann_wide::LoHi::new(input, input >> N).zip::<u8s_32>();
combined & u8s_32::splat(input.arch(), (1u8 << N) - 1)
}
let unpack_crumbs = |x: u8s_16| -> (u8s_32, u8s_32) {
let nibbles = unpack_sub::<4>(x);
let diskann_wide::LoHi { lo, hi } = nibbles.split();
let lower = unpack_sub::<2>(lo);
let upper = unpack_sub::<2>(hi);
(lower, upper)
};
let ones = i16s::splat(arch, 1);
while i + 4 <= blocks {
let (vx0, vx1, (vy0, vy1)) = unsafe {
(
u8s_32::load_simd(arch, px.add(64 * i)),
u8s_32::load_simd(arch, px.add(64 * i + 32)),
unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * i))),
)
};
let m0a = products(vx0, vy0);
let m0b = products(vx1, vy1);
let (vx0, vx1, (vy0, vy1)) = unsafe {
(
u8s_32::load_simd(arch, px.add(64 * (i + 1))),
u8s_32::load_simd(arch, px.add(64 * (i + 1) + 32)),
unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 1)))),
)
};
let m1a = products(vx0, vy0);
let m1b = products(vx1, vy1);
let (vx0, vx1, (vy0, vy1)) = unsafe {
(
u8s_32::load_simd(arch, px.add(64 * (i + 2))),
u8s_32::load_simd(arch, px.add(64 * (i + 2) + 32)),
unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 2)))),
)
};
let m2a = products(vx0, vy0);
let m2b = products(vx1, vy1);
let (vx0, vx1, (vy0, vy1)) = unsafe {
(
u8s_32::load_simd(arch, px.add(64 * (i + 3))),
u8s_32::load_simd(arch, px.add(64 * (i + 3) + 32)),
unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * (i + 3)))),
)
};
let m3a = products(vx0, vy0);
let m3b = products(vx1, vy1);
acc = acc.dot_simd((m0a + m0b + m1a + m1b) + (m2a + m2b + m3a + m3b), ones);
i += 4;
}
while i < blocks {
let (vx0, vx1, (vy0, vy1)) = unsafe {
(
u8s_32::load_simd(arch, px.add(64 * i)),
u8s_32::load_simd(arch, px.add(64 * i + 32)),
unpack_crumbs(u8s_16::load_simd(arch, py.add(16 * i))),
)
};
acc = acc.dot_simd(products(vx0, vy0) + products(vx1, vy1), ones);
i += 1;
}
s = acc.sum_tree() as u32;
}
i *= 64;
if i != len {
#[inline(never)]
fn fallback(x: USlice<'_, 8>, y: USlice<'_, 2>, from: usize) -> u32 {
let mut s: u32 = 0;
for i in from..x.len() {
let (ix, iy) =
unsafe { (x.get_unchecked(i) as u32, y.get_unchecked(i) as u32) };
s += ix * iy;
}
s
}
s += fallback(x, y, i);
}
Ok(MV::new(s))
}
}
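// Mixed 8-bit x 1-bit inner product on x86-64-v3: each y bit becomes a byte
// mask that keeps or zeroes the matching x byte, and `_mm256_sad_epu8`
// against zero adds the surviving bytes into widely spaced lanes that are
// safe to accumulate as i32s.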
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<u32>, USlice<'_, 8>, USlice<'_, 1>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: USlice<'_, 8>,
y: USlice<'_, 1>,
) -> MathematicalResult<u32> {
use diskann_wide::{FromInt, SIMDMask};
use std::arch::x86_64::_mm256_sad_epu8;
let len = check_lengths!(x, y)?;
diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
diskann_wide::alias!(u8s_32 = <diskann_wide::arch::x86_64::V3>::u8x32);
type Mask32 = diskann_wide::BitMask<32, diskann_wide::arch::x86_64::V3>;
type Mask8x32 = diskann_wide::arch::x86_64::v3::masks::mask8x32;
let px: *const u8 = x.as_ptr();
let py: *const u8 = y.as_ptr();
let mut i: usize = 0;
let mut s: u32 = 0;
let blocks = len / 32;
if blocks > 0 {
let mut acc = i32s::default(arch);
let zero = u8s_32::default(arch);
let masked_sad = |vx: u8s_32, bits: u32| -> i32s {
let byte_mask: Mask8x32 = Mask32::from_int(arch, bits).into();
let masked = vx & u8s_32::from_underlying(arch, byte_mask.to_underlying());
i32s::from_underlying(arch, unsafe {
_mm256_sad_epu8(masked.to_underlying(), zero.to_underlying())
})
};
while i + 4 <= blocks {
let s0 = unsafe {
let vx = u8s_32::load_simd(arch, px.add(32 * i));
let bits = (py.add(4 * i) as *const u32).read_unaligned();
masked_sad(vx, bits)
};
let s1 = unsafe {
let vx = u8s_32::load_simd(arch, px.add(32 * (i + 1)));
let bits = (py.add(4 * (i + 1)) as *const u32).read_unaligned();
masked_sad(vx, bits)
};
let s2 = unsafe {
let vx = u8s_32::load_simd(arch, px.add(32 * (i + 2)));
let bits = (py.add(4 * (i + 2)) as *const u32).read_unaligned();
masked_sad(vx, bits)
};
let s3 = unsafe {
let vx = u8s_32::load_simd(arch, px.add(32 * (i + 3)));
let bits = (py.add(4 * (i + 3)) as *const u32).read_unaligned();
masked_sad(vx, bits)
};
acc = acc + s0 + s1 + s2 + s3;
i += 4;
}
while i < blocks {
let si = unsafe {
let vx = u8s_32::load_simd(arch, px.add(32 * i));
let bits = (py.add(4 * i) as *const u32).read_unaligned();
masked_sad(vx, bits)
};
acc = acc + si;
i += 1;
}
s = acc.sum_tree() as u32;
}
i *= 32;
if i != len {
#[inline(never)]
fn fallback(x: USlice<'_, 8>, y: USlice<'_, 1>, from: usize) -> u32 {
let mut s: u32 = 0;
for i in from..x.len() {
let (ix, iy) =
unsafe { (x.get_unchecked(i) as u32, y.get_unchecked(i) as u32) };
s += ix * iy;
}
s
}
s += fallback(x, y, i);
}
Ok(MV::new(s))
}
}
#[cfg(target_arch = "aarch64")]
retarget!(
diskann_wide::arch::aarch64::Neon,
InnerProduct,
7,
6,
5,
4,
3,
2,
(8, 4),
(8, 2),
(8, 1)
);
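// Inner product between a bit-transposed 4-bit vector and a dense 1-bit
// vector: plane `k` of x stores bit `k` of 64 consecutive elements, so
// `popcount(bits & plane_k) << k` is the total contribution of that bit
// plane.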
impl<A> Target2<A, MathematicalResult<u32>, USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>>
for InnerProduct
where
A: Architecture,
{
#[inline(always)]
fn run(
self,
_: A,
x: USlice<'_, 4, BitTranspose>,
y: USlice<'_, 1, Dense>,
) -> MathematicalResult<u32> {
let len = check_lengths!(x, y)?;
let px: *const u64 = x.as_ptr().cast();
let py: *const u64 = y.as_ptr().cast();
let mut i = 0;
let mut s: u32 = 0;
let blocks = len / 64;
while i < blocks {
let bits = unsafe { py.add(i).read_unaligned() };
let b0 = unsafe { px.add(4 * i).read_unaligned() };
s += (bits & b0).count_ones();
let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
s += (bits & b1).count_ones() << 1;
let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
s += (bits & b2).count_ones() << 2;
let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
s += (bits & b3).count_ones() << 3;
i += 1;
}
if 64 * i == len {
return Ok(MV::new(s));
}
let k = i * 8;
let py = unsafe { py.cast::<u8>().add(k) };
let bytes_remaining = y.bytes() - k;
let mut bits: u64 = 0;
for j in 0..bytes_remaining.min(8) {
bits += (unsafe { py.add(j).read() } as u64) << (8 * j);
}
bits &= (0x01u64 << (len - (64 * i))) - 1;
let b0 = unsafe { px.add(4 * i).read_unaligned() };
s += (bits & b0).count_ones();
let b1 = unsafe { px.add(4 * i + 1).read_unaligned() };
s += (bits & b1).count_ones() << 1;
let b2 = unsafe { px.add(4 * i + 2).read_unaligned() };
s += (bits & b2).count_ones() << 2;
let b3 = unsafe { px.add(4 * i + 3).read_unaligned() };
s += (bits & b3).count_ones() << 3;
Ok(MV::new(s))
}
}
impl
PureDistanceFunction<USlice<'_, 4, BitTranspose>, USlice<'_, 1, Dense>, MathematicalResult<u32>>
for InnerProduct
{
fn evaluate(
x: USlice<'_, 4, BitTranspose>,
y: USlice<'_, 1, Dense>,
) -> MathematicalResult<u32> {
(diskann_wide::ARCH).run2(Self, x, y)
}
}
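// f32 x 1-bit inner product on x86-64-v3: per-lane shifts move one query bit
// into the low bits of each lane, and `_mm256_permutevar_ps` (which indexes
// by the low two bits of each lane) maps it to 0.0 or 1.0 for a fused
// multiply-add.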
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 1>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: &[f32],
y: USlice<'_, 1>,
) -> MathematicalResult<f32> {
let len = check_lengths!(x, y)?;
use std::arch::x86_64::*;
diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
let values = f32s::from_array(arch, [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
let variable_shifts = u32s::from_array(arch, [0, 1, 2, 3, 4, 5, 6, 7]);
let px: *const f32 = x.as_ptr();
let py: *const u32 = y.as_ptr().cast();
let mut i = 0;
let mut s = f32s::default(arch);
let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
let to_f32 = |v: u32s| -> f32s {
f32s::from_underlying(arch, unsafe {
_mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
})
};
let blocks = len / 32;
if i < blocks {
let mut s0 = f32s::default(arch);
let mut s1 = f32s::default(arch);
while i < blocks {
let iy = prep(unsafe { py.add(i).read_unaligned() });
let ix0 = unsafe { f32s::load_simd(arch, px.add(32 * i)) };
let ix1 = unsafe { f32s::load_simd(arch, px.add(32 * i + 8)) };
let ix2 = unsafe { f32s::load_simd(arch, px.add(32 * i + 16)) };
let ix3 = unsafe { f32s::load_simd(arch, px.add(32 * i + 24)) };
s0 = ix0.mul_add_simd(to_f32(iy), s0);
s1 = ix1.mul_add_simd(to_f32(iy >> 8), s1);
s0 = ix2.mul_add_simd(to_f32(iy >> 16), s0);
s1 = ix3.mul_add_simd(to_f32(iy >> 24), s1);
i += 1;
}
s = s0 + s1;
}
let remainder = len % 32;
if remainder != 0 {
let tail = if len % 8 == 0 { 8 } else { len % 8 };
let py = unsafe { py.add(blocks) };
if remainder <= 8 {
unsafe {
load_one(py, |iy| {
let iy = prep(iy);
let ix = f32s::load_simd_first(arch, px.add(32 * blocks), tail);
s = ix.mul_add_simd(to_f32(iy), s);
})
}
} else if remainder <= 16 {
unsafe {
load_two(py, |iy| {
let iy = prep(iy);
let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
let ix1 = f32s::load_simd_first(arch, px.add(32 * blocks + 8), tail);
s = ix0.mul_add_simd(to_f32(iy), s);
s = ix1.mul_add_simd(to_f32(iy >> 8), s);
})
}
} else if remainder <= 24 {
unsafe {
load_three(py, |iy| {
let iy = prep(iy);
let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
let ix2 = f32s::load_simd_first(arch, px.add(32 * blocks + 16), tail);
s = ix0.mul_add_simd(to_f32(iy), s);
s = ix1.mul_add_simd(to_f32(iy >> 8), s);
s = ix2.mul_add_simd(to_f32(iy >> 16), s);
})
}
} else {
unsafe {
load_four(py, |iy| {
let iy = prep(iy);
let ix0 = f32s::load_simd(arch, px.add(32 * blocks));
let ix1 = f32s::load_simd(arch, px.add(32 * blocks + 8));
let ix2 = f32s::load_simd(arch, px.add(32 * blocks + 16));
let ix3 = f32s::load_simd_first(arch, px.add(32 * blocks + 24), tail);
s = ix0.mul_add_simd(to_f32(iy), s);
s = ix1.mul_add_simd(to_f32(iy >> 8), s);
s = ix2.mul_add_simd(to_f32(iy >> 16), s);
s = ix3.mul_add_simd(to_f32(iy >> 24), s);
})
}
}
}
Ok(MV::new(s.sum_tree()))
}
}
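// f32 x 2-bit inner product on x86-64-v3: the per-lane shifts expose one
// crumb in the low two bits of each lane, which `_mm256_permutevar_ps` maps
// directly to 0.0..=3.0.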
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 2>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: &[f32],
y: USlice<'_, 2>,
) -> MathematicalResult<f32> {
let len = check_lengths!(x, y)?;
use std::arch::x86_64::*;
diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
diskann_wide::alias!(u32s = <diskann_wide::arch::x86_64::V3>::u32x8);
let values = f32s::from_array(arch, [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);
let variable_shifts = u32s::from_array(arch, [0, 2, 4, 6, 8, 10, 12, 14]);
let px: *const f32 = x.as_ptr();
let py: *const u32 = y.as_ptr().cast();
let mut i = 0;
let mut s = f32s::default(arch);
let prep = |v: u32| -> u32s { u32s::splat(arch, v) >> variable_shifts };
let to_f32 = |v: u32s| -> f32s {
f32s::from_underlying(arch, unsafe {
_mm256_permutevar_ps(values.to_underlying(), v.to_underlying())
})
};
let blocks = len / 16;
if blocks != 0 {
let mut s0 = f32s::default(arch);
let mut s1 = f32s::default(arch);
while i + 2 <= blocks {
let iy = prep(unsafe { py.add(i).read_unaligned() });
let (ix0, ix1) = unsafe {
(
f32s::load_simd(arch, px.add(16 * i)),
f32s::load_simd(arch, px.add(16 * i + 8)),
)
};
s0 = ix0.mul_add_simd(to_f32(iy), s0);
s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
let iy = prep(unsafe { py.add(i + 1).read_unaligned() });
let (ix0, ix1) = unsafe {
(
f32s::load_simd(arch, px.add(16 * (i + 1))),
f32s::load_simd(arch, px.add(16 * (i + 1) + 8)),
)
};
s0 = ix0.mul_add_simd(to_f32(iy), s0);
s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
i += 2;
}
if i < blocks {
let iy = prep(unsafe { py.add(i).read_unaligned() });
let (ix0, ix1) = unsafe {
(
f32s::load_simd(arch, px.add(16 * i)),
f32s::load_simd(arch, px.add(16 * i + 8)),
)
};
s0 = ix0.mul_add_simd(to_f32(iy), s0);
s1 = ix1.mul_add_simd(to_f32(iy >> 16), s1);
}
s = s0 + s1;
}
let remainder = len % 16;
if remainder != 0 {
let tail = if len % 8 == 0 { 8 } else { len % 8 };
let py = unsafe { py.add(blocks) };
if remainder <= 4 {
unsafe {
load_one(py, |iy| {
let iy = prep(iy);
let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
s = ix.mul_add_simd(to_f32(iy), s);
});
}
} else if remainder <= 8 {
unsafe {
load_two(py, |iy| {
let iy = prep(iy);
let ix = f32s::load_simd_first(arch, px.add(16 * blocks), tail);
s = ix.mul_add_simd(to_f32(iy), s);
});
}
} else if remainder <= 12 {
unsafe {
load_three(py, |iy| {
let iy = prep(iy);
let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
s = ix0.mul_add_simd(to_f32(iy), s);
s = ix1.mul_add_simd(to_f32(iy >> 16), s);
});
}
} else {
unsafe {
load_four(py, |iy| {
let iy = prep(iy);
let ix0 = f32s::load_simd(arch, px.add(16 * blocks));
let ix1 = f32s::load_simd_first(arch, px.add(16 * blocks + 8), tail);
s = ix0.mul_add_simd(to_f32(iy), s);
s = ix1.mul_add_simd(to_f32(iy >> 16), s);
});
}
}
}
Ok(MV::new(s.sum_tree()))
}
}
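// f32 x 4-bit inner product on x86-64-v3: nibbles are exposed with per-lane
// shifts, masked to 0..=15, and converted to f32 with a lane-wise cast
// instead of a table lookup.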
#[cfg(target_arch = "x86_64")]
impl Target2<diskann_wide::arch::x86_64::V3, MathematicalResult<f32>, &[f32], USlice<'_, 4>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: diskann_wide::arch::x86_64::V3,
x: &[f32],
y: USlice<'_, 4>,
) -> MathematicalResult<f32> {
let len = check_lengths!(x, y)?;
diskann_wide::alias!(f32s = <diskann_wide::arch::x86_64::V3>::f32x8);
diskann_wide::alias!(i32s = <diskann_wide::arch::x86_64::V3>::i32x8);
let variable_shifts = i32s::from_array(arch, [0, 4, 8, 12, 16, 20, 24, 28]);
let mask = i32s::splat(arch, 0x0f);
let to_f32 = |v: u32| -> f32s {
((i32s::splat(arch, v as i32) >> variable_shifts) & mask).simd_cast()
};
let px: *const f32 = x.as_ptr();
let py: *const u32 = y.as_ptr().cast();
let mut i = 0;
let mut s = f32s::default(arch);
let blocks = len / 8;
while i < blocks {
let ix = unsafe { f32s::load_simd(arch, px.add(8 * i)) };
let iy = to_f32(unsafe { py.add(i).read_unaligned() });
s = ix.mul_add_simd(iy, s);
i += 1;
}
let remainder = len % 8;
if remainder != 0 {
let f = |iy| {
let ix = unsafe { f32s::load_simd_first(arch, px.add(8 * blocks), remainder) };
s = ix.mul_add_simd(to_f32(iy), s);
};
let py = unsafe { py.add(blocks) };
if remainder <= 2 {
unsafe { load_one(py, f) };
} else if remainder <= 4 {
unsafe { load_two(py, f) };
} else if remainder <= 6 {
unsafe { load_three(py, f) };
} else {
unsafe { load_four(py, f) };
}
}
Ok(MV::new(s.sum_tree()))
}
}
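// Scalar reference implementation of the float-times-quantized inner product
// for every supported bit width.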
impl<const N: usize>
Target2<diskann_wide::arch::Scalar, MathematicalResult<f32>, &[f32], USlice<'_, N>>
for InnerProduct
where
Unsigned: Representation<N>,
{
#[inline(always)]
fn run(
self,
_: diskann_wide::arch::Scalar,
x: &[f32],
y: USlice<'_, N>,
) -> MathematicalResult<f32> {
check_lengths!(x, y)?;
let mut s = 0.0;
for (i, x) in x.iter().enumerate() {
let y = unsafe { y.get_unchecked(i) } as f32;
s += x * y;
}
Ok(MV::new(s))
}
}
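// Forwards float-times-quantized inner products down one architecture level,
// mirroring `retarget!` above.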
macro_rules! ip_retarget {
($arch:path, $N:literal) => {
impl Target2<$arch, MathematicalResult<f32>, &[f32], USlice<'_, $N>>
for InnerProduct
{
#[inline(always)]
fn run(
self,
arch: $arch,
x: &[f32],
y: USlice<'_, $N>,
) -> MathematicalResult<f32> {
self.run(arch.retarget(), x, y)
}
}
};
($arch:path, $($Ns:literal),*) => {
$(ip_retarget!($arch, $Ns);)*
}
}
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V3, 3, 5, 6, 7, 8);
#[cfg(target_arch = "x86_64")]
ip_retarget!(diskann_wide::arch::x86_64::V4, 1, 2, 3, 4, 5, 6, 7, 8);
#[cfg(target_arch = "aarch64")]
ip_retarget!(diskann_wide::arch::aarch64::Neon, 1, 2, 3, 4, 5, 6, 7, 8);
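// Public entry points: evaluate with whatever architecture was detected at
// startup.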
macro_rules! dispatch_full_ip {
($N:literal) => {
impl PureDistanceFunction<&[f32], USlice<'_, $N>, MathematicalResult<f32>>
for InnerProduct
{
fn evaluate(x: &[f32], y: USlice<'_, $N>) -> MathematicalResult<f32> {
Self.run(ARCH, x, y)
}
}
};
($($Ns:literal),*) => {
$(dispatch_full_ip!($Ns);)*
}
}
dispatch_full_ip!(1, 2, 3, 4, 5, 6, 7, 8);
#[cfg(test)]
mod tests {
use std::{collections::HashMap, fmt::Display, sync::LazyLock};
use diskann_utils::{Reborrow, lazy_format};
use rand::{
Rng, SeedableRng,
distr::{Distribution, Uniform},
rngs::StdRng,
seq::IndexedRandom,
};
use super::*;
use crate::bits::{BoxedBitSlice, Representation, Unsigned};
type MR = MathematicalResult<u32>;
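// Under miri, checking every dimension is too slow; restrict to powers of
// two, their successors, and dimensions below 64 that are congruent to
// 7 mod 8.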
#[inline(always)]
fn should_check_this_dimension(dim: usize) -> bool {
if cfg!(miri) {
return dim.is_power_of_two()
|| (dim > 1 && (dim - 1).is_power_of_two())
|| (dim < 64 && (dim % 8 == 7));
}
true
}
fn test_bitslice_distances<const NBITS: usize, R>(
dim_max: usize,
trials_per_dim: usize,
evaluate_l2: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
evaluate_ip: &dyn Fn(USlice<'_, NBITS>, USlice<'_, NBITS>) -> MR,
context: &str,
rng: &mut R,
) where
Unsigned: Representation<NBITS>,
R: Rng,
{
let domain = Unsigned::domain_const::<NBITS>();
let min: i64 = *domain.start();
let max: i64 = *domain.end();
let dist = Uniform::new_inclusive(min, max).unwrap();
for dim in 0..dim_max {
if !should_check_this_dimension(dim) {
continue;
}
let mut x_reference: Vec<u8> = vec![0; dim];
let mut y_reference: Vec<u8> = vec![0; dim];
let mut x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
let mut y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(dim);
for trial in 0..trials_per_dim {
x_reference
.iter_mut()
.for_each(|i| *i = dist.sample(rng).try_into().unwrap());
y_reference
.iter_mut()
.for_each(|i| *i = dist.sample(rng).try_into().unwrap());
x.as_mut_slice().fill(u8::MAX);
y.as_mut_slice().fill(u8::MAX);
for i in 0..dim {
x.set(i, x_reference[i].into()).unwrap();
y.set(i, y_reference[i].into()).unwrap();
}
let expected: MV<f32> =
diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
let got = evaluate_l2(x.reborrow(), y.reborrow()).unwrap();
assert_eq!(
expected.into_inner(),
got.into_inner() as f32,
"failed SquaredL2 for NBITS = {}, dim = {}, trial = {} -- context {}",
NBITS,
dim,
trial,
context,
);
let expected: MV<f32> =
diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
let got = evaluate_ip(x.reborrow(), y.reborrow()).unwrap();
assert_eq!(
expected.into_inner(),
got.into_inner() as f32,
"faild InnerProduct for NBITS = {}, dim = {}, trial = {} -- context {}",
NBITS,
dim,
trial,
context,
);
}
}
let x = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(10);
let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
assert!(
evaluate_l2(x.reborrow(), y.reborrow()).is_err(),
"context: {}",
context
);
assert!(
evaluate_l2(y.reborrow(), x.reborrow()).is_err(),
"context: {}",
context
);
assert!(
evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
"context: {}",
context
);
assert!(
evaluate_ip(y.reborrow(), x.reborrow()).is_err(),
"context: {}",
context
);
}
cfg_if::cfg_if! {
if #[cfg(miri)] {
const MAX_DIM: usize = 132;
const TRIALS_PER_DIM: usize = 1;
} else {
const MAX_DIM: usize = 256;
const TRIALS_PER_DIM: usize = 20;
}
}
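// Per-(bit-width, architecture) dimension bounds: `standard` for native
// runs, with a smaller `miri` bound to keep interpreted runs tractable.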
static BITSLICE_TEST_BOUNDS: LazyLock<HashMap<Key, Bounds>> = LazyLock::new(|| {
use ArchKey::{Neon, Scalar, X86_64_V3, X86_64_V4};
[
(Key::new(1, Scalar), Bounds::new(64, 64)),
(Key::new(1, X86_64_V3), Bounds::new(256, 256)),
(Key::new(1, X86_64_V4), Bounds::new(256, 256)),
(Key::new(1, Neon), Bounds::new(64, 64)),
(Key::new(2, Scalar), Bounds::new(64, 64)),
(Key::new(2, X86_64_V3), Bounds::new(512, 300)),
(Key::new(2, X86_64_V4), Bounds::new(768, 600)),
(Key::new(2, Neon), Bounds::new(64, 64)),
(Key::new(3, Scalar), Bounds::new(64, 64)),
(Key::new(3, X86_64_V3), Bounds::new(256, 96)),
(Key::new(3, X86_64_V4), Bounds::new(256, 96)),
(Key::new(3, Neon), Bounds::new(64, 64)),
(Key::new(4, Scalar), Bounds::new(64, 64)),
(Key::new(4, X86_64_V3), Bounds::new(256, 150)),
(Key::new(4, X86_64_V4), Bounds::new(256, 150)),
(Key::new(4, Neon), Bounds::new(64, 64)),
(Key::new(5, Scalar), Bounds::new(64, 64)),
(Key::new(5, X86_64_V3), Bounds::new(256, 96)),
(Key::new(5, X86_64_V4), Bounds::new(256, 96)),
(Key::new(5, Neon), Bounds::new(64, 64)),
(Key::new(6, Scalar), Bounds::new(64, 64)),
(Key::new(6, X86_64_V3), Bounds::new(256, 96)),
(Key::new(6, X86_64_V4), Bounds::new(256, 96)),
(Key::new(6, Neon), Bounds::new(64, 64)),
(Key::new(7, Scalar), Bounds::new(64, 64)),
(Key::new(7, X86_64_V3), Bounds::new(256, 96)),
(Key::new(7, X86_64_V4), Bounds::new(256, 96)),
(Key::new(7, Neon), Bounds::new(64, 64)),
(Key::new(8, Scalar), Bounds::new(64, 64)),
(Key::new(8, X86_64_V3), Bounds::new(256, 96)),
(Key::new(8, X86_64_V4), Bounds::new(256, 96)),
(Key::new(8, Neon), Bounds::new(64, 64)),
]
.into_iter()
.collect()
});
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum ArchKey {
Scalar,
#[expect(non_camel_case_types)]
X86_64_V3,
#[expect(non_camel_case_types)]
X86_64_V4,
Neon,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Key {
nbits: usize,
arch: ArchKey,
}
impl Key {
fn new(nbits: usize, arch: ArchKey) -> Self {
Self { nbits, arch }
}
}
#[derive(Debug, Clone, Copy)]
struct Bounds {
standard: usize,
miri: usize,
}
impl Bounds {
fn new(standard: usize, miri: usize) -> Self {
Self { standard, miri }
}
fn get(&self) -> usize {
if cfg!(miri) { self.miri } else { self.standard }
}
}
macro_rules! test_bitslice {
($name:ident, $nbits:literal, $seed:literal) => {
#[test]
fn $name() {
let mut rng = StdRng::seed_from_u64($seed);
let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Scalar)].get();
test_bitslice_distances::<$nbits, _>(
max_dim,
TRIALS_PER_DIM,
&|x, y| SquaredL2::evaluate(x, y),
&|x, y| InnerProduct::evaluate(x, y),
"pure distance function",
&mut rng,
);
test_bitslice_distances::<$nbits, _>(
max_dim,
TRIALS_PER_DIM,
&|x, y| diskann_wide::arch::Scalar::new().run2(SquaredL2, x, y),
&|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
"scalar arch",
&mut rng,
);
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V3)].get();
test_bitslice_distances::<$nbits, _>(
max_dim,
TRIALS_PER_DIM,
&|x, y| arch.run2(SquaredL2, x, y),
&|x, y| arch.run2(InnerProduct, x, y),
"x86-64-v3",
&mut rng,
);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::X86_64_V4)].get();
test_bitslice_distances::<$nbits, _>(
max_dim,
TRIALS_PER_DIM,
&|x, y| arch.run2(SquaredL2, x, y),
&|x, y| arch.run2(InnerProduct, x, y),
"x86-64-v4",
&mut rng,
);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
let max_dim = BITSLICE_TEST_BOUNDS[&Key::new($nbits, ArchKey::Neon)].get();
test_bitslice_distances::<$nbits, _>(
max_dim,
TRIALS_PER_DIM,
&|x, y| arch.run2(SquaredL2, x, y),
&|x, y| arch.run2(InnerProduct, x, y),
"neon",
&mut rng,
);
}
}
};
}
test_bitslice!(test_bitslice_distances_8bit, 8, 0xf0330c6d880e08ff);
test_bitslice!(test_bitslice_distances_7bit, 7, 0x98aa7f2d4c83844f);
test_bitslice!(test_bitslice_distances_6bit, 6, 0xf2f7ad7a37764b4c);
test_bitslice!(test_bitslice_distances_5bit, 5, 0xae878d14973fb43f);
test_bitslice!(test_bitslice_distances_4bit, 4, 0x8d6dbb8a6b19a4f8);
test_bitslice!(test_bitslice_distances_3bit, 3, 0x8f56767236e58da2);
test_bitslice!(test_bitslice_distances_2bit, 2, 0xb04f741a257b61af);
test_bitslice!(test_bitslice_distances_1bit, 1, 0x820ea031c379eab5);
fn test_hamming_distances<R>(dim_max: usize, trials_per_dim: usize, rng: &mut R)
where
R: Rng,
{
let dist: [i8; 2] = [-1, 1];
for dim in 0..dim_max {
if !should_check_this_dimension(dim) {
continue;
}
let mut x_reference: Vec<i8> = vec![1; dim];
let mut y_reference: Vec<i8> = vec![1; dim];
let mut x = BoxedBitSlice::<1, Binary>::new_boxed(dim);
let mut y = BoxedBitSlice::<1, Binary>::new_boxed(dim);
for _ in 0..trials_per_dim {
x_reference
.iter_mut()
.for_each(|i| *i = *dist.choose(rng).unwrap());
y_reference
.iter_mut()
.for_each(|i| *i = *dist.choose(rng).unwrap());
x.as_mut_slice().fill(u8::MAX);
y.as_mut_slice().fill(u8::MAX);
for i in 0..dim {
x.set(i, x_reference[i].into()).unwrap();
y.set(i, y_reference[i].into()).unwrap();
}
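// On {-1, +1} data each mismatching coordinate contributes 4 to squared-L2,
// so 4 * Hamming must equal the reference distance.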
let expected: MV<f32> =
diskann_vector::distance::SquaredL2::evaluate(&*x_reference, &*y_reference);
let got: MV<u32> = Hamming::evaluate(x.reborrow(), y.reborrow()).unwrap();
assert_eq!(4.0 * (got.into_inner() as f32), expected.into_inner());
}
}
let x = BoxedBitSlice::<1, Binary>::new_boxed(10);
let y = BoxedBitSlice::<1, Binary>::new_boxed(11);
assert!(Hamming::evaluate(x.reborrow(), y.reborrow()).is_err());
assert!(Hamming::evaluate(y.reborrow(), x.reborrow()).is_err());
}
#[test]
fn test_hamming_distance() {
let mut rng = StdRng::seed_from_u64(0x2160419161246d97);
test_hamming_distances(MAX_DIM, TRIALS_PER_DIM, &mut rng);
}
fn test_bit_transpose_distances<R>(
dim_max: usize,
trials_per_dim: usize,
evaluate_ip: &dyn Fn(USlice<'_, 4, BitTranspose>, USlice<'_, 1>) -> MR,
context: &str,
rng: &mut R,
) where
R: Rng,
{
let dist_4bit = {
let domain = Unsigned::domain_const::<4>();
Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
};
let dist_1bit = {
let domain = Unsigned::domain_const::<1>();
Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
};
for dim in 0..dim_max {
if !should_check_this_dimension(dim) {
continue;
}
let mut x_reference: Vec<u8> = vec![0; dim];
let mut y_reference: Vec<u8> = vec![0; dim];
let mut x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(dim);
let mut y = BoxedBitSlice::<1, Unsigned, Dense>::new_boxed(dim);
for trial in 0..trials_per_dim {
x_reference
.iter_mut()
.for_each(|i| *i = dist_4bit.sample(rng).try_into().unwrap());
y_reference
.iter_mut()
.for_each(|i| *i = dist_1bit.sample(rng).try_into().unwrap());
x.as_mut_slice().fill(u8::MAX);
y.as_mut_slice().fill(u8::MAX);
for i in 0..dim {
x.set(i, x_reference[i].into()).unwrap();
y.set(i, y_reference[i].into()).unwrap();
}
let expected: MV<f32> =
diskann_vector::distance::InnerProduct::evaluate(&*x_reference, &*y_reference);
let got = evaluate_ip(x.reborrow(), y.reborrow());
assert_eq!(
expected.into_inner(),
got.unwrap().into_inner() as f32,
"faild InnerProduct for dim = {}, trial = {} -- context {}",
dim,
trial,
context,
);
}
}
let x = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(10);
let y = BoxedBitSlice::<1, Unsigned>::new_boxed(11);
assert!(
evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
"context: {}",
context
);
let y = BoxedBitSlice::<1, Unsigned>::new_boxed(9);
assert!(
evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
"context: {}",
context
);
}
#[test]
fn test_bit_transpose_distance() {
let mut rng = StdRng::seed_from_u64(0xe20e26e926d4b853);
test_bit_transpose_distances(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| InnerProduct::evaluate(x, y),
"pure distance function",
&mut rng,
);
test_bit_transpose_distances(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
"scalar",
&mut rng,
);
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
test_bit_transpose_distances(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| arch.run2(InnerProduct, x, y),
"x86-64-v3",
&mut rng,
);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
test_bit_transpose_distances(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| arch.run2(InnerProduct, x, y),
"x86-64-v4",
&mut rng,
);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
test_bit_transpose_distances(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| arch.run2(InnerProduct, x, y),
"neon",
&mut rng,
);
}
}
fn test_full_distances<const NBITS: usize>(
dim_max: usize,
trials_per_dim: usize,
evaluate_ip: &dyn Fn(&[f32], USlice<'_, NBITS>) -> MathematicalResult<f32>,
context: &str,
rng: &mut impl Rng,
) where
Unsigned: Representation<NBITS>,
{
let dist_float = [-2.0, -1.0, 0.0, 1.0, 2.0];
let dist_bit = {
let domain = Unsigned::domain_const::<NBITS>();
Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap()
};
for dim in 0..dim_max {
if !should_check_this_dimension(dim) {
continue;
}
let mut x: Vec<f32> = vec![0.0; dim];
let mut y_reference: Vec<u8> = vec![0; dim];
let mut y = BoxedBitSlice::<NBITS, Unsigned, Dense>::new_boxed(dim);
for trial in 0..trials_per_dim {
x.iter_mut()
.for_each(|i| *i = *dist_float.choose(rng).unwrap());
y_reference
.iter_mut()
.for_each(|i| *i = dist_bit.sample(rng).try_into().unwrap());
y.as_mut_slice().fill(u8::MAX);
let mut expected = 0.0;
for i in 0..dim {
y.set(i, y_reference[i].into()).unwrap();
expected += y_reference[i] as f32 * x[i];
}
let got = evaluate_ip(&x, y.reborrow()).unwrap();
assert_eq!(
expected,
got.into_inner(),
"faild InnerProduct for dim = {}, trial = {} -- context {}",
dim,
trial,
context,
);
let scalar: MV<f32> = InnerProduct
.run(diskann_wide::arch::Scalar, x.as_slice(), y.reborrow())
.unwrap();
assert_eq!(got.into_inner(), scalar.into_inner());
}
}
let x = vec![0.0; 10];
let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(11);
assert!(
evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
"context: {}",
context
);
let y = BoxedBitSlice::<NBITS, Unsigned>::new_boxed(9);
assert!(
evaluate_ip(x.as_slice(), y.reborrow()).is_err(),
"context: {}",
context
);
}
macro_rules! test_full {
($name:ident, $nbits:literal, $seed:literal) => {
#[test]
fn $name() {
let mut rng = StdRng::seed_from_u64($seed);
test_full_distances::<$nbits>(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| InnerProduct::evaluate(x, y),
"pure distance function",
&mut rng,
);
test_full_distances::<$nbits>(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
"scalar",
&mut rng,
);
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
test_full_distances::<$nbits>(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| arch.run2(InnerProduct, x, y),
"x86-64-v3",
&mut rng,
);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
test_full_distances::<$nbits>(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| arch.run2(InnerProduct, x, y),
"x86-64-v4",
&mut rng,
);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = diskann_wide::arch::aarch64::Neon::new_checked() {
test_full_distances::<$nbits>(
MAX_DIM,
TRIALS_PER_DIM,
&|x, y| arch.run2(InnerProduct, x, y),
"neon",
&mut rng,
);
}
}
};
}
test_full!(test_full_distance_1bit, 1, 0xe20e26e926d4b853);
test_full!(test_full_distance_2bit, 2, 0xae9542700aecbf68);
test_full!(test_full_distance_3bit, 3, 0xfffd04b26bb6068c);
test_full!(test_full_distance_4bit, 4, 0x86db49fd1a1704ba);
test_full!(test_full_distance_5bit, 5, 0x3a35dc7fa7931c41);
test_full!(test_full_distance_6bit, 6, 0x1f69de79e418d336);
test_full!(test_full_distance_7bit, 7, 0x3fcf17b82dadc5ab);
test_full!(test_full_distance_8bit, 8, 0x85dcaf48b1399db2);
struct HetCase<const M: usize> {
x_vals: Vec<i64>,
y_vals: Vec<i64>,
}
impl<const M: usize> HetCase<M>
where
Unsigned: Representation<M>,
{
fn new(dim: usize, fill: impl FnMut(usize) -> (i64, i64)) -> Self {
let (x_vals, y_vals) = (0..dim).map(fill).unzip();
Self { x_vals, y_vals }
}
fn check_with(
&self,
label: impl Display,
evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
) {
let dim = self.x_vals.len();
let mut x = BoxedBitSlice::<8, Unsigned>::new_boxed(dim);
let mut y = BoxedBitSlice::<M, Unsigned>::new_boxed(dim);
x.as_mut_slice().fill(u8::MAX);
y.as_mut_slice().fill(u8::MAX);
for (i, (&xv, &yv)) in self.x_vals.iter().zip(&self.y_vals).enumerate() {
x.set(i, xv).unwrap();
y.set(i, yv).unwrap();
}
let expected: u32 = self
.x_vals
.iter()
.zip(&self.y_vals)
.map(|(&a, &b)| a as u32 * b as u32)
.sum();
let got = evaluate(x.reborrow(), y.reborrow()).unwrap().into_inner();
assert_eq!(expected, got, "{} failed for dim = {}", label, dim);
}
}
fn fuzz_heterogeneous_ip<const M: usize>(
dim_max: usize,
trials_per_dim: usize,
max_val: i64,
evaluate_ip: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
context: &str,
rng: &mut impl Rng,
) where
Unsigned: Representation<M>,
{
let dist_8bit = Uniform::new_inclusive(0i64, 255i64).unwrap();
let dist_mbit = Uniform::new_inclusive(0i64, max_val).unwrap();
for dim in 0..dim_max {
for trial in 0..trials_per_dim {
HetCase::<M>::new(dim, |_| {
(dist_8bit.sample(&mut *rng), dist_mbit.sample(&mut *rng))
})
.check_with(
lazy_format!("IP(8,{}) dim={dim}, trial={trial} -- {context}", M),
evaluate_ip,
);
}
let x = BoxedBitSlice::<8, Unsigned>::new_boxed(dim);
let y = BoxedBitSlice::<M, Unsigned>::new_boxed(dim + 1);
assert!(
evaluate_ip(x.reborrow(), y.reborrow()).is_err(),
"context: {}",
context,
);
}
}
fn het_test_max_values<const M: usize>(
max_val: i64,
context: &str,
evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
) where
Unsigned: Representation<M>,
{
let dims = [127, 128, 129, 255, 256, 512, 768, 896, 3072];
for &dim in &dims {
let case = HetCase::<M>::new(dim, |_| (255, max_val));
case.check_with(lazy_format!("max-value {context} dim={dim}"), evaluate);
}
}
fn het_test_known_answers<const M: usize>(
max_val: i64,
evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
) where
Unsigned: Representation<M>,
{
HetCase::<M>::new(64, |_| (200, max_val)).check_with("vpmaddubsw operand-order", evaluate);
let y_val = (max_val / 2).max(1);
HetCase::<M>::new(128, |i| ((i % 256) as i64, y_val))
.check_with("ascending-x constant-y", evaluate);
HetCase::<M>::new(1, |_| (200, max_val)).check_with("single element", evaluate);
}
fn het_test_edge_cases<const M: usize>(
max_val: i64,
block_size: usize,
evaluate: &dyn Fn(USlice<'_, 8>, USlice<'_, M>) -> MR,
) where
Unsigned: Representation<M>,
{
let y_half = (max_val / 2).max(1);
HetCase::<M>::new(64, |_| (0, max_val)).check_with("x-zero y-nonzero", evaluate);
HetCase::<M>::new(64, |_| (255, 0)).check_with("y-zero x-nonzero", evaluate);
for dim in 0..=(block_size + 1) {
HetCase::<M>::new(dim, |_| (3, y_half)).check_with("uniform fill", evaluate);
}
for &dim in &[block_size, 2 * block_size, 4 * block_size, 8 * block_size] {
HetCase::<M>::new(dim, |_| (100, max_val)).check_with("exact block boundary", evaluate);
}
HetCase::<M>::new(300, |i| ((i % 256) as i64, 1))
.check_with("x-varies y-constant", evaluate);
HetCase::<M>::new(300, |i| (1, (i as i64) % (max_val + 1)))
.check_with("x-constant y-varies", evaluate);
HetCase::<M>::new(128, |i| if i % 2 == 0 { (255, max_val) } else { (0, 0) })
.check_with("alternating pattern", evaluate);
HetCase::<M>::new(128, |i| if i % 2 == 0 { (0, 0) } else { (255, max_val) })
.check_with("opposite alternating", evaluate);
HetCase::<M>::new(1024, |_| (255, max_val)).check_with("large accumulation", evaluate);
for x_val in [128i64, 170, 200, 240, 255] {
HetCase::<M>::new(block_size, move |_| (x_val, y_half))
.check_with(lazy_format!("x > 127 (x_val={x_val})"), evaluate);
}
HetCase::<M>::new(block_size - 1, |i| {
(
((i * 7 + 3) % 256) as i64,
((i * 11 + 5) as i64) % (max_val + 1),
)
})
.check_with("dim=block_size-1 (all scalar)", evaluate);
let unroll4 = 4 * block_size;
for &dim in &[
unroll4,
unroll4 + 1,
unroll4 + block_size,
unroll4 + block_size + 1,
] {
HetCase::<M>::new(dim, |i| {
(((i + 1) % 256) as i64, ((i + 1) as i64) % (max_val + 1))
})
.check_with("unroll boundary", evaluate);
}
}
macro_rules! heterogeneous_ip_tests_8xM {
(
mod_name: $mod:ident,
M: $M:literal,
max_val: $max_val:literal,
block_size: $block_size:literal,
seed_fuzz: $seed_fuzz:literal,
) => {
mod $mod {
use super::*;
#[test]
fn all_ip_dispatches() {
let mut rng = StdRng::seed_from_u64($seed_fuzz);
fuzz_heterogeneous_ip::<$M>(
MAX_DIM,
TRIALS_PER_DIM,
$max_val,
&|x, y| InnerProduct::evaluate(x, y),
"pure distance function",
&mut rng,
);
fuzz_heterogeneous_ip::<$M>(
MAX_DIM,
TRIALS_PER_DIM,
$max_val,
&|x, y| diskann_wide::arch::Scalar::new().run2(InnerProduct, x, y),
"scalar arch",
&mut rng,
);
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
fuzz_heterogeneous_ip::<$M>(
MAX_DIM,
TRIALS_PER_DIM,
$max_val,
&|x, y| arch.run2(InnerProduct, x, y),
"x86-64-v3",
&mut rng,
);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V4::new_checked_miri() {
fuzz_heterogeneous_ip::<$M>(
MAX_DIM,
TRIALS_PER_DIM,
$max_val,
&|x, y| arch.run2(InnerProduct, x, y),
"x86-64-v4",
&mut rng,
);
}
}
#[test]
fn max_values() {
het_test_max_values::<$M>($max_val, "dispatch", &|x, y| {
InnerProduct::evaluate(x, y)
});
#[cfg(target_arch = "x86_64")]
if let Some(arch) = diskann_wide::arch::x86_64::V3::new_checked() {
het_test_max_values::<$M>($max_val, "V3", &|x, y| {
arch.run2(InnerProduct, x, y)
});
}
}
#[test]
fn known_answers() {
het_test_known_answers::<$M>($max_val, &|x, y| InnerProduct::evaluate(x, y));
}
#[test]
fn edge_cases() {
het_test_edge_cases::<$M>($max_val, $block_size, &|x, y| {
InnerProduct::evaluate(x, y)
});
}
}
};
}
heterogeneous_ip_tests_8xM! {
mod_name: heterogeneous_ip_8x4,
M: 4,
max_val: 15,
block_size: 32,
seed_fuzz: 0xd3a7f1c09b2e4856,
}
heterogeneous_ip_tests_8xM! {
mod_name: heterogeneous_ip_8x2,
M: 2,
max_val: 3,
block_size: 64,
seed_fuzz: 0x82c4a6e809f1d3b5,
}
heterogeneous_ip_tests_8xM! {
mod_name: heterogeneous_ip_8x1,
M: 1,
max_val: 1,
block_size: 32,
seed_fuzz: 0x1b17a5e7c2d0f839,
}
}