use std::fmt;
use diskann_utils::strided;
use diskann_wide::{SIMDMask, SIMDMulAdd, SIMDPartialOrd, SIMDSelect, SIMDVector};
use crate::{
algorithms::kmeans,
distances::{InnerProduct, SquaredL2},
multi_vector::BlockTransposed,
};
diskann_wide::alias!(f32s = f32x8);
diskann_wide::alias!(u32s = u32x8);
/// Errors that can occur when constructing a [`Chunk`] from a strided source.
#[derive(Debug, Clone)]
pub enum ChunkConstructionError {
/// The source view has zero columns (dimensions).
DimensionCannotBeZero,
/// The source view has zero rows (length).
LengthCannotBeZero,
}
impl fmt::Display for ChunkConstructionError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ChunkConstructionError::DimensionCannotBeZero => write!(
f,
"cannot construct a Chunk from a source with zero dimensions"
)?,
ChunkConstructionError::LengthCannotBeZero => {
write!(f, "cannot construct a Chunk from a source with zero length")?
}
}
Ok(())
}
}
// Marker impl: `Display` and `Debug` above satisfy the `Error` supertraits.
impl std::error::Error for ChunkConstructionError {}
/// Result of compressing a query against a [`Chunk`]: the index of the closest
/// center, or the sentinel `u32::MAX` when no finite distance was found (e.g.
/// the query contained NaN or infinities).
#[derive(Debug, Clone, Copy)]
pub struct CompressionResult(u32);
impl CompressionResult {
/// The sentinel value representing a failed compression.
fn err() -> Self {
Self(u32::MAX)
}
/// Unconditionally extract the inner index (may be the error sentinel).
pub(super) fn into_inner(self) -> u32 {
self.0
}
/// `true` when this result holds a valid center index (not the sentinel).
pub fn is_okay(&self) -> bool {
self.0 != u32::MAX
}
/// Convert into a `Result`, calling `ok` with the index on success or
/// `err` to build the error value on failure.
pub fn map<F, G, R, E>(self, ok: F, err: G) -> Result<R, E>
where
F: FnOnce(u32) -> R,
G: FnOnce() -> E,
{
if self.is_okay() {
Ok(ok(self.0))
} else {
Err(err())
}
}
/// Test-only helper: assert success and return the index.
#[cfg(test)]
pub(crate) fn unwrap(self) -> u32 {
assert!(self.is_okay());
self.0
}
}
/// A group of cluster centers stored in a block-transposed layout (blocks of
/// 16 rows) together with the precomputed squared L2 norm of each center.
#[derive(Debug)]
pub struct Chunk {
// Centers, transposed in blocks of 16 rows for SIMD-friendly access.
data: BlockTransposed<f32, 16>,
// Squared L2 norm of each center (one entry per row of `data`).
square_norms: Vec<f32>,
}
impl Chunk {
/// Number of rows per transposed block (16): two `f32s` registers' worth.
const fn groupsize() -> usize {
BlockTransposed::<f32, 16>::const_group_size()
}
/// Number of queries processed together by `find_closest_batch`.
pub(super) const fn batchsize() -> usize {
4
}
/// Dimensionality of each center.
pub(super) fn dimension(&self) -> usize {
self.data.ncols()
}
/// Total number of centers stored in this chunk.
pub(super) fn num_centers(&self) -> usize {
self.data.nrows()
}
/// Number of (possibly partial) transposed blocks.
pub(super) fn num_blocks(&self) -> usize {
self.data.num_blocks()
}
/// Number of centers in the final partial block (0 when all blocks are
/// full — callers guard with `remainder != 0` before touching the tail).
pub(super) fn remainder(&self) -> usize {
self.data.remainder()
}
/// Test-only element accessor with explicit bounds checks, reading back a
/// single value from the block-transposed storage.
#[cfg(test)]
pub(super) fn get(&self, row: usize, col: usize) -> f32 {
assert!(
row < self.num_centers(),
"row {} must be less than {}",
row,
self.num_centers()
);
assert!(
col < self.dimension(),
"col {} must be less than {}",
col,
self.dimension()
);
self.data[(row, col)]
}
/// Build a `Chunk` from a strided view of centers: precompute each row's
/// squared L2 norm, then transpose the data into blocks of 16 rows.
///
/// # Errors
/// `DimensionCannotBeZero` if the view has zero columns;
/// `LengthCannotBeZero` if it has zero rows (checked in that order).
pub(super) fn new(data: strided::StridedView<'_, f32>) -> Result<Self, ChunkConstructionError> {
if data.ncols() == 0 {
return Err(ChunkConstructionError::DimensionCannotBeZero);
}
if data.nrows() == 0 {
return Err(ChunkConstructionError::LengthCannotBeZero);
}
let square_norms = data.row_iter().map(kmeans::square_norm).collect();
let data = BlockTransposed::<f32, 16>::from_strided(data);
Ok(Self { data, square_norms })
}
/// Number of completely filled transposed blocks (excludes a partial tail).
fn full_blocks(&self) -> usize {
self.data.full_blocks()
}
pub(super) fn find_closest<T>(&self, x: &[T]) -> CompressionResult
where
T: Copy + Into<f32>,
{
assert_eq!(x.len(), self.dimension(), "incorrect query dimension");
let mut min_distances: f32s = f32s::splat(diskann_wide::ARCH, f32::INFINITY);
let mut min_indices = u32s::splat(diskann_wide::ARCH, u32::MAX);
let index_offsets = u32s::from_array(diskann_wide::ARCH, [0, 1, 2, 3, 4, 5, 6, 7]);
let full_blocks = self.full_blocks();
for block in 0..full_blocks {
let (d0, d1) = self.compute_in_block::<InnerProductMathematical, T>(x, block);
let (norm0, norm1) = unsafe { self.load_norms(block) };
let d0 = norm0 - (d0 + d0);
let d1 = norm1 - (d1 + d1);
(min_distances, min_indices) = update_tracking_with(
min_distances,
min_indices,
d0,
(Self::groupsize() * block) as u32,
index_offsets,
);
(min_distances, min_indices) = update_tracking_with(
min_distances,
min_indices,
d1,
(Self::groupsize() * block + f32s::LANES) as u32,
index_offsets,
)
}
let remainder = self.remainder();
if remainder != 0 {
let (d0, d1) = self.compute_in_remainder::<InnerProductMathematical, T>(x);
let (norm0, norm1) = unsafe { self.load_remainder_norms() };
let d0 = norm0 - (d0 + d0);
let d1 = norm1 - (d1 + d1);
(min_distances, min_indices) = update_tracking_with(
min_distances,
min_indices,
d0,
(Self::groupsize() * full_blocks) as u32,
index_offsets,
);
(min_distances, min_indices) = update_tracking_with(
min_distances,
min_indices,
d1,
(Self::groupsize() * full_blocks + f32s::LANES) as u32,
index_offsets,
);
}
let mut minimum_distance = f32::MAX;
let mut minimum_index = u32::MAX;
for (&i, &d) in std::iter::zip(
min_indices.to_array().iter(),
min_distances.to_array().iter(),
) {
if d < minimum_distance {
minimum_distance = d;
minimum_index = i;
}
}
if minimum_distance.is_finite() {
CompressionResult(minimum_index)
} else {
CompressionResult::err()
}
}
/// Find the closest center (squared L2) for each of the `batchsize()` query
/// rows of `x` in one fused pass, amortizing the pivot loads across the four
/// queries. Returns one `CompressionResult` per query row.
///
/// # Panics
/// Panics unless `x` has exactly `batchsize()` rows and `dimension()` cols.
pub(super) fn find_closest_batch<T>(
&self,
x: strided::StridedView<'_, T>,
) -> [CompressionResult; Self::batchsize()]
where
T: Copy + Into<f32>,
{
assert_eq!(
x.nrows(),
Self::batchsize(),
"argument StridedView must have a length of {}",
Self::batchsize()
);
assert_eq!(x.ncols(), self.dimension(), "incorrect query dimension");
let dim = self.dimension();
// Per-query running (min-distance, min-index) lane trackers.
let mut tracking: [(f32s, u32s); Self::batchsize()] = [
(
f32s::splat(diskann_wide::ARCH, f32::INFINITY),
u32s::splat(diskann_wide::ARCH, u32::MAX),
),
(
f32s::splat(diskann_wide::ARCH, f32::INFINITY),
u32s::splat(diskann_wide::ARCH, u32::MAX),
),
(
f32s::splat(diskann_wide::ARCH, f32::INFINITY),
u32s::splat(diskann_wide::ARCH, u32::MAX),
),
(
f32s::splat(diskann_wide::ARCH, f32::INFINITY),
u32s::splat(diskann_wide::ARCH, u32::MAX),
),
];
let index_offsets = u32s::from_array(diskann_wide::ARCH, [0, 1, 2, 3, 4, 5, 6, 7]);
// Broadcast element `k` of each of the four query rows into a register.
let unsafe_load_x = |k| {
debug_assert!(k < x.ncols());
// SAFETY: `k < ncols()` (debug-checked above) and the view has
// exactly 4 rows (asserted at entry, debug-checked here).
unsafe {
debug_assert!(3 < x.nrows());
(
f32s::splat(
diskann_wide::ARCH,
<T as Into<f32>>::into(*x.get_row_unchecked(0).get_unchecked(k)),
),
f32s::splat(
diskann_wide::ARCH,
<T as Into<f32>>::into(*x.get_row_unchecked(1).get_unchecked(k)),
),
f32s::splat(
diskann_wide::ARCH,
<T as Into<f32>>::into(*x.get_row_unchecked(2).get_unchecked(k)),
),
f32s::splat(
diskann_wide::ARCH,
<T as Into<f32>>::into(*x.get_row_unchecked(3).get_unchecked(k)),
),
)
}
};
let remainder = self.remainder();
let base_ptr = self.data.as_ptr();
const STRIDE: usize = Chunk::groupsize();
// Load the 16 center values of dimension `dim_offset` within `block`
// (two registers) from the block-transposed storage.
let unsafe_load_pivots = |block, dim_offset| -> (f32s, f32s) {
debug_assert!(block < self.num_blocks());
debug_assert!(dim_offset < self.dimension());
let index = STRIDE * (block * dim + dim_offset);
// SAFETY: `block` and `dim_offset` are debug-checked in range, so the
// computed offset stays inside the transposed storage.
unsafe {
(
f32s::load_simd(diskann_wide::ARCH, base_ptr.add(index)),
f32s::load_simd(diskann_wide::ARCH, base_ptr.add(index + f32s::LANES)),
)
}
};
for i in 0..self.num_blocks() {
// Inner-product accumulators: one (lo, hi) register pair per query.
let mut d0_0 = f32s::default(diskann_wide::ARCH);
let mut d0_1 = f32s::default(diskann_wide::ARCH);
let mut d1_0 = f32s::default(diskann_wide::ARCH);
let mut d1_1 = f32s::default(diskann_wide::ARCH);
let mut d2_0 = f32s::default(diskann_wide::ARCH);
let mut d2_1 = f32s::default(diskann_wide::ARCH);
let mut d3_0 = f32s::default(diskann_wide::ARCH);
let mut d3_1 = f32s::default(diskann_wide::ARCH);
// Manually unroll the dimension loop two deep.
const INNER_UNROLL: usize = 2;
let unrolled_iterations = dim / INNER_UNROLL;
for j in 0..unrolled_iterations {
let j_linear = INNER_UNROLL * j;
let (a0, a1, a2, a3) = unsafe_load_x(j_linear);
let (b0, b1) = unsafe_load_pivots(i, j_linear);
d0_0 = a0.mul_add_simd(b0, d0_0);
d0_1 = a0.mul_add_simd(b1, d0_1);
d1_0 = a1.mul_add_simd(b0, d1_0);
d1_1 = a1.mul_add_simd(b1, d1_1);
d2_0 = a2.mul_add_simd(b0, d2_0);
d2_1 = a2.mul_add_simd(b1, d2_1);
d3_0 = a3.mul_add_simd(b0, d3_0);
d3_1 = a3.mul_add_simd(b1, d3_1);
let (a0, a1, a2, a3) = unsafe_load_x(j_linear + 1);
let (b0, b1) = unsafe_load_pivots(i, j_linear + 1);
d0_0 = a0.mul_add_simd(b0, d0_0);
d0_1 = a0.mul_add_simd(b1, d0_1);
d1_0 = a1.mul_add_simd(b0, d1_0);
d1_1 = a1.mul_add_simd(b1, d1_1);
d2_0 = a2.mul_add_simd(b0, d2_0);
d2_1 = a2.mul_add_simd(b1, d2_1);
d3_0 = a3.mul_add_simd(b0, d3_0);
d3_1 = a3.mul_add_simd(b1, d3_1);
}
// Handle the odd trailing dimension when `dim` is not a multiple of 2.
let last_unrolled = INNER_UNROLL * unrolled_iterations;
if last_unrolled != dim {
debug_assert!(last_unrolled + 1 == dim);
let (a0, a1, a2, a3) = unsafe_load_x(last_unrolled);
let (b0, b1) = unsafe_load_pivots(i, last_unrolled);
d0_0 = a0.mul_add_simd(b0, d0_0);
d0_1 = a0.mul_add_simd(b1, d0_1);
d1_0 = a1.mul_add_simd(b0, d1_0);
d1_1 = a1.mul_add_simd(b1, d1_1);
d2_0 = a2.mul_add_simd(b0, d2_0);
d2_1 = a2.mul_add_simd(b1, d2_1);
d3_0 = a3.mul_add_simd(b0, d3_0);
d3_1 = a3.mul_add_simd(b1, d3_1);
}
let (norm_0, norm_1) = if remainder != 0 && i + 1 == self.num_blocks() {
// Partial final block: force invalid lanes to a -inf inner product
// so the derived distances below become +inf and never win.
let infinity = f32s::splat(diskann_wide::ARCH, f32::NEG_INFINITY);
let lo = remainder.min(f32s::LANES);
let hi = remainder - lo;
let mask_lo = <f32s as SIMDVector>::Mask::keep_first(diskann_wide::ARCH, lo);
d0_0 = mask_lo.select(d0_0, infinity);
d1_0 = mask_lo.select(d1_0, infinity);
d2_0 = mask_lo.select(d2_0, infinity);
d3_0 = mask_lo.select(d3_0, infinity);
let mask_hi = <f32s as SIMDVector>::Mask::keep_first(diskann_wide::ARCH, hi);
d0_1 = mask_hi.select(d0_1, infinity);
d1_1 = mask_hi.select(d1_1, infinity);
d2_1 = mask_hi.select(d2_1, infinity);
d3_1 = mask_hi.select(d3_1, infinity);
// SAFETY: `remainder != 0` in this branch.
unsafe { self.load_remainder_norms() }
} else {
// SAFETY: `i` indexes a full block in this branch.
unsafe { self.load_norms(i) }
};
// Distance surrogate: ||c||^2 - 2<c, x> (the query norm is a per-query
// constant and does not affect the argmin).
let two = f32s::splat(diskann_wide::ARCH, 2.0f32);
d0_0 = norm_0 - two * d0_0;
d0_1 = norm_1 - two * d0_1;
d1_0 = norm_0 - two * d1_0;
d1_1 = norm_1 - two * d1_1;
d2_0 = norm_0 - two * d2_0;
d2_1 = norm_1 - two * d2_1;
d3_0 = norm_0 - two * d3_0;
d3_1 = norm_1 - two * d3_1;
let ind_0 =
u32s::splat(diskann_wide::ARCH, (Self::groupsize() * i) as u32) + index_offsets;
let ind_1 = u32s::splat(
diskann_wide::ARCH,
(Self::groupsize() * i + f32s::LANES) as u32,
) + index_offsets;
tracking[0] =
update_tracking(tracking[0], update_tracking((d0_0, ind_0), (d0_1, ind_1)));
tracking[1] =
update_tracking(tracking[1], update_tracking((d1_0, ind_0), (d1_1, ind_1)));
tracking[2] =
update_tracking(tracking[2], update_tracking((d2_0, ind_0), (d2_1, ind_1)));
tracking[3] =
update_tracking(tracking[3], update_tracking((d3_0, ind_0), (d3_1, ind_1)));
}
// Horizontal reduction per query: a non-finite minimum means the query
// produced no valid distance (e.g. it contained NaN or infinity).
let finish = |(distances, indices): (f32s, u32s)| {
let mut minimum_distance = f32::INFINITY;
let mut minimum_index = u32::MAX;
for (&i, &d) in std::iter::zip(indices.to_array().iter(), distances.to_array().iter()) {
if d < minimum_distance {
minimum_distance = d;
minimum_index = i;
}
}
if minimum_distance.is_finite() {
CompressionResult(minimum_index)
} else {
CompressionResult::err()
}
};
[
finish(tracking[0]),
finish(tracking[1]),
finish(tracking[2]),
finish(tracking[3]),
]
}
/// Load the 16 squared norms of full block `block` into two registers.
///
/// # Safety
/// `block` must index a full block (debug-checked); `square_norms` must hold
/// at least `groupsize() * (block + 1)` elements for the loads to be in
/// bounds.
#[inline(always)]
unsafe fn load_norms(&self, block: usize) -> (f32s, f32s) {
debug_assert!(block < self.full_blocks());
let ptr = unsafe { self.square_norms.as_ptr().add(Chunk::groupsize() * block) };
unsafe {
(
f32s::load_simd(diskann_wide::ARCH, ptr),
f32s::load_simd(diskann_wide::ARCH, ptr.add(f32s::LANES)),
)
}
}
/// Load the squared norms of the final partial block, zero-filling lanes
/// past the remainder.
///
/// # Safety
/// Must only be called when `remainder() != 0` (debug-checked), so that the
/// partial tail of `square_norms` exists.
#[inline(always)]
unsafe fn load_remainder_norms(&self) -> (f32s, f32s) {
let remainder = self.remainder();
debug_assert!(remainder != 0);
// Number of valid lanes in the single partially-filled register.
let first = remainder % f32s::LANES;
let ptr = unsafe {
self.square_norms
.as_ptr()
.add(Chunk::groupsize() * self.full_blocks())
};
if remainder < f32s::LANES {
// Only the low register has valid entries; the high one is zeroed.
unsafe {
(
f32s::load_simd_first(diskann_wide::ARCH, ptr, first),
f32s::default(diskann_wide::ARCH),
)
}
} else {
// Low register fully valid; partially fill the high register. Note
// `first == 0` when `remainder == f32s::LANES` exactly.
unsafe {
(
f32s::load_simd(diskann_wide::ARCH, ptr),
f32s::load_simd_first(diskann_wide::ARCH, ptr.add(f32s::LANES), first),
)
}
}
}
/// Accumulate kernel `Op` between query `x` and the 16 centers of `block`,
/// returning two registers of per-center partial results.
fn compute_in_block<Op, T>(&self, x: &[T], block: usize) -> (f32s, f32s)
where
Op: ComputeKernel,
T: Copy + Into<f32>,
{
assert_eq!(x.len(), self.dimension());
assert!(block < self.data.num_blocks());
// SAFETY: `block` was bounds-checked by the assert above.
let ptr = unsafe { self.data.block_ptr_unchecked(block) };
let acc = (
f32s::default(diskann_wide::ARCH),
f32s::default(diskann_wide::ARCH),
);
// Walk the query dimensions; the block-transposed layout stores the 16
// center values for dimension `i` contiguously at `groupsize() * i`.
x.iter().enumerate().fold(acc, |acc, (i, x)| {
// SAFETY: `i < dimension()` and each block stores `groupsize()`
// values per dimension, so both register loads stay in the block.
let a0 =
unsafe { f32s::load_simd(diskann_wide::ARCH, ptr.add(Chunk::groupsize() * i)) };
let a1 = unsafe {
f32s::load_simd(
diskann_wide::ARCH,
ptr.add(Chunk::groupsize() * i + f32s::LANES),
)
};
Op::step(*x, (a0, a1), acc)
})
}
/// Compute kernel `Op` between `x` and the centers of the final (partial)
/// block, replacing the lanes past the remainder with `Op::REMAINDER`
/// padding so they cannot win subsequent minimum searches.
fn compute_in_remainder<Op, T>(&self, x: &[T]) -> (f32s, f32s)
where
Op: ComputeKernel,
T: Copy + Into<f32>,
{
let d = self.compute_in_block::<Op, T>(x, self.data.full_blocks());
let remainder = self.remainder();
// Use `f32s::LANES` rather than the hard-coded `8` so the mask width
// stays tied to the SIMD alias (matches `load_remainder_norms`). When
// `remainder == LANES` the modulus is 0 and the mask keeps no lane of
// the high register, which is exactly what the `else` branch needs.
let keep = <f32s as SIMDVector>::Mask::keep_first(diskann_wide::ARCH, remainder % f32s::LANES);
let padding = f32s::splat(diskann_wide::ARCH, Op::REMAINDER);
if remainder < f32s::LANES {
// Only the low register has valid lanes; the high one is all padding.
(keep.select(d.0, padding), padding)
} else {
// Low register fully valid; mask the tail of the high register.
(d.0, keep.select(d.1, padding))
}
}
}
/// Merge a new batch of candidate distances into the running per-lane
/// minima. The candidate lane indices are `base_index + [0..LANES)`.
#[inline(always)]
fn update_tracking_with(
    min_distances: f32s,
    min_indices: u32s,
    distances: f32s,
    base_index: u32,
    index_offsets: u32s,
) -> (f32s, u32s) {
    // Materialize the candidate index vector first, then defer to the
    // generic two-way lane-minimum combinator.
    let base = u32s::splat(diskann_wide::ARCH, base_index);
    let candidates = (distances, base + index_offsets);
    update_tracking((min_distances, min_indices), candidates)
}
/// Per-lane minimum of two (distance, index) candidate pairs. The comparison
/// is strict (`lt`), so ties keep the first argument — this is what preserves
/// lower-index tie-breaking in the callers.
#[inline(always)]
fn update_tracking((d0, i0): (f32s, u32s), (d1, i1): (f32s, u32s)) -> (f32s, u32s) {
let mask = d1.lt_simd(d0);
(
mask.select(d1, d0),
<u32s as SIMDVector>::Mask::from(mask).select(i1, i0),
)
}
/// One fused multiply-accumulate step of a distance kernel, applying a single
/// query element against two registers of center values.
trait ComputeKernel {
/// Value used to pad lanes past the valid remainder, chosen per kernel so
/// padded lanes are excluded from later minimum searches.
const REMAINDER: f32;
/// Fold query element `x` and center values `y` into the accumulators.
fn step<T>(x: T, y: (f32s, f32s), accumulator: (f32s, f32s)) -> (f32s, f32s)
where
T: Into<f32>;
}
// Accumulates the sum of squared differences: acc += (y - x)^2 per lane.
impl ComputeKernel for SquaredL2 {
// Padded lanes get +inf so they always lose a minimum-distance search.
const REMAINDER: f32 = f32::INFINITY;
#[inline(always)]
fn step<T>(x: T, y: (f32s, f32s), accumulator: (f32s, f32s)) -> (f32s, f32s)
where
T: Into<f32>,
{
let b = f32s::splat(diskann_wide::ARCH, x.into());
let d0 = y.0 - b;
let d1 = y.1 - b;
(
d0.mul_add_simd(d0, accumulator.0),
d1.mul_add_simd(d1, accumulator.1),
)
}
}
// Accumulates the *negated* inner product (acc += -x * y), matching the
// smaller-is-better distance convention.
impl ComputeKernel for InnerProduct {
// Padded lanes get +inf so they always lose a minimum search.
const REMAINDER: f32 = f32::INFINITY;
#[inline(always)]
fn step<T>(x: T, y: (f32s, f32s), accumulator: (f32s, f32s)) -> (f32s, f32s)
where
T: Into<f32>,
{
let x: f32 = x.into();
// Negate the query element once; the fused multiply-add then yields
// acc - x * y in a single instruction per register.
let b = f32s::splat(diskann_wide::ARCH, -x);
(
b.mul_add_simd(y.0, accumulator.0),
b.mul_add_simd(y.1, accumulator.1),
)
}
}
/// The plain (non-negated) inner product, used by the `find_closest` distance
/// reduction which forms `||c||^2 - 2<c, x>` from it.
struct InnerProductMathematical;
impl ComputeKernel for InnerProductMathematical {
// Padded lanes get -inf so that `norm - 2 * (-inf) = +inf` distance, which
// can never win a minimum search.
const REMAINDER: f32 = f32::NEG_INFINITY;
#[inline(always)]
fn step<T>(x: T, y: (f32s, f32s), accumulator: (f32s, f32s)) -> (f32s, f32s)
where
T: Into<f32>,
{
let x: f32 = x.into();
let b = f32s::splat(diskann_wide::ARCH, x);
(
b.mul_add_simd(y.0, accumulator.0),
b.mul_add_simd(y.1, accumulator.1),
)
}
}
/// Evaluate a distance kernel between one query and every center in a chunk,
/// writing one score per center.
pub trait ProcessInto {
/// Fill `into[i]` with the kernel applied to `from` and center `i`.
fn process_into(chunk: &Chunk, from: &[f32], into: &mut [f32]);
}
// Blanket implementation: any `ComputeKernel` can be evaluated chunk-wide.
impl<T> ProcessInto for T
where
T: ComputeKernel,
{
fn process_into(chunk: &Chunk, from: &[f32], into: &mut [f32]) {
assert_eq!(from.len(), chunk.dimension());
assert_eq!(into.len(), chunk.num_centers());
let ptr = into.as_mut_ptr();
let full_blocks = chunk.full_blocks();
let remainder = chunk.remainder();
for block in 0..chunk.num_blocks() {
let (lo, hi) = chunk.compute_in_block::<T, f32>(from, block);
if remainder != 0 && block == full_blocks {
// Partial final block: store only the valid lanes.
let keep_lo = remainder.min(f32s::LANES);
let keep_hi = remainder - keep_lo;
// SAFETY: the assert above pins `into.len() == num_centers()`;
// the masked stores write exactly the `remainder` tail entries.
unsafe { lo.store_simd_first(ptr.add(Chunk::groupsize() * full_blocks), keep_lo) };
if keep_hi != 0 {
unsafe {
hi.store_simd_first(
ptr.add(Chunk::groupsize() * full_blocks + f32s::LANES),
keep_hi,
)
};
}
} else {
// Full block: store both registers unconditionally.
// SAFETY: a full block covers `groupsize()` contiguous entries
// of `into`, all within its asserted length.
unsafe {
lo.store_simd(ptr.add(Chunk::groupsize() * block));
hi.store_simd(ptr.add(Chunk::groupsize() * block + f32s::LANES));
}
}
}
}
}
#[cfg(test)]
mod tests {
use diskann_utils::{lazy_format, views};
use diskann_vector::{PureDistanceFunction, distance};
use rand::{
SeedableRng,
distr::{Distribution, Uniform},
rngs::StdRng,
};
use super::*;
// `CompressionResult` unit tests: sentinel round-trips, okay/err
// classification, and that `map` calls exactly one of its two closures.
#[test]
fn compression_result() {
let v = CompressionResult::err();
assert_eq!(v.into_inner(), u32::MAX);
assert!(!v.is_okay());
// Every non-sentinel value round-trips through unwrap/into_inner.
for i in 0u32..1000u32 {
let v = CompressionResult(i);
assert!(v.is_okay());
assert_eq!(v.unwrap(), i);
assert_eq!(v.into_inner(), i);
}
{
// Success path: only the `ok` closure runs, receiving the index.
let mut called_ok = false;
let mut called_err = false;
let x: Result<&str, &str> = CompressionResult(10).map(
|v| {
called_ok = true;
assert_eq!(v, 10);
"okay!"
},
|| {
called_err = true;
"not okay!"
},
);
assert_eq!(x.unwrap(), "okay!");
assert!(called_ok);
assert!(!called_err);
}
{
// Failure path: only the `err` closure runs.
let mut called_ok = false;
let mut called_err = false;
let x: Result<&str, &str> = CompressionResult::err().map(
|_| {
called_ok = true;
"okay!"
},
|| {
called_err = true;
"not okay!"
},
);
assert_eq!(x.unwrap_err(), "not okay!");
assert!(!called_ok);
assert!(called_err);
}
}
// `unwrap` on the error sentinel must panic.
#[test]
#[should_panic]
fn compression_result_unwrap_panics() {
CompressionResult::err().unwrap();
}
/// Concatenate equal-length rows into one contiguous buffer, asserting that
/// the input is non-empty and rectangular.
fn flatten(x: &[Vec<f32>]) -> Vec<f32> {
    assert!(!x.is_empty());
    let dim = x[0].len();
    assert!(x.iter().all(|row| row.len() == dim));
    let flattened: Vec<f32> = x.iter().flat_map(|row| row.iter().copied()).collect();
    assert_eq!(flattened.len(), dim * x.len());
    flattened
}
/// Build `total` rows of `dim` values each, where row `i` holds the ramp
/// `[i, i+1, ..., i+dim-1]` as f32.
fn create_test_pattern(dim: usize, total: usize) -> Vec<Vec<f32>> {
    let mut rows = Vec::with_capacity(total);
    for i in 0..total {
        let row: Vec<f32> = (i..i + dim).map(|value| value as f32).collect();
        rows.push(row);
    }
    rows
}
// Exercise `find_closest_batch` by placing `query` in each of the four batch
// lanes in turn (all other lanes zeroed), checking that the active lane
// matches `expected_closest` and the zeroed lanes match `zero_matches`.
// Then repeat with a non-finite value broadcast into the active lane to
// check that only that lane reports failure.
fn test_batch(
chunk: &Chunk,
query: &[f32],
expected_closest: usize,
zero_matches: usize,
test_context: &dyn std::fmt::Display,
) {
let dim = query.len();
for j in 0..Chunk::batchsize() {
// Row `j` carries the query; every other row is zero.
let copy_query = |k| {
if k == j {
query.to_vec()
} else {
vec![0.0; dim]
}
};
assert_eq!(
4,
Chunk::batchsize(),
"if the lower level batch size changes, update the function below"
);
let query_batch =
flatten(&[copy_query(0), copy_query(1), copy_query(2), copy_query(3)]);
let view = strided::StridedView::try_from(
query_batch.as_slice(),
Chunk::batchsize(),
dim,
dim,
)
.unwrap();
assert_eq!(view.nrows(), Chunk::batchsize());
assert_eq!(view.ncols(), dim);
// Sanity-check the constructed batch before using it.
for k in 0..view.nrows() {
let row = view.row(k);
if k == j {
assert_eq!(
row, query,
"expected entry {k} to be the query ({test_context})"
);
} else {
assert!(
row.iter().all(|&k| k == 0.0),
"expected inactive rows to be zero ({test_context})"
);
}
}
let closest: [CompressionResult; Chunk::batchsize()] = chunk.find_closest_batch(view);
for (k, &got) in closest.iter().enumerate() {
let got = got.unwrap() as usize;
if k == j {
assert_eq!(
got, expected_closest,
"failed to match active lane {k} ({test_context})"
);
} else {
assert_eq!(
got, zero_matches,
"inactive lane {k} assigned to the wrong center ({test_context})"
);
}
}
// Second pass: poison lane `j` with a non-finite value; other lanes
// keep the real query and must still succeed.
let maybe_broadcast = |k, v: f32| {
if k == j { vec![v; dim] } else { query.to_vec() }
};
let values = [f32::INFINITY, f32::NEG_INFINITY, f32::NAN, f32::INFINITY];
let query_batch = flatten(&[
maybe_broadcast(0, values[0]),
maybe_broadcast(1, values[1]),
maybe_broadcast(2, values[2]),
maybe_broadcast(3, values[3]),
]);
let view = strided::StridedView::try_from(
query_batch.as_slice(),
Chunk::batchsize(),
dim,
dim,
)
.unwrap();
let closest = chunk.find_closest_batch(view);
for (k, &got) in closest.iter().enumerate() {
if k == j {
assert!(
!got.is_okay(),
"lane {} should not be okay with value {} ({})",
k,
values[k],
test_context
);
} else {
assert_eq!(
got.unwrap() as usize,
expected_closest,
"failed to match active lane {k} ({test_context})"
);
}
}
}
}
// End-to-end check of `Chunk` for one (dim, total) shape: construction,
// element round-trip via `get`, rejection of non-finite queries, exact and
// near-miss nearest-center lookups (single and batched), and tie-breaking.
fn test_chunk(dim: usize, total: usize) {
let test_context = lazy_format!("ndims {}, total {}", dim, total);
let mut data_aggregate = create_test_pattern(dim, total);
let data = flatten(&data_aggregate);
let sliced = strided::StridedView::try_from(data.as_slice(), total, dim, dim).unwrap();
let chunk = Chunk::new(sliced).unwrap();
assert_eq!(chunk.num_centers(), total);
assert_eq!(chunk.dimension(), dim);
// The block-transposed storage must round-trip every element.
for row in 0..sliced.nrows() {
for col in 0..sliced.ncols() {
assert_eq!(
sliced[(row, col)],
chunk.get(row, col),
"failed on row {} and col {}",
row,
col,
);
}
}
// Non-finite queries must fail rather than return an arbitrary center.
assert!(!chunk.find_closest(&vec![f32::NEG_INFINITY; dim]).is_okay());
assert!(!chunk.find_closest(&vec![f32::INFINITY; dim]).is_okay());
assert!(!chunk.find_closest(&vec![f32::NAN; dim]).is_okay());
let query: Vec<f32> = vec![0.0; dim];
assert_eq!(chunk.find_closest(&query).unwrap(), 0);
// Offsetting pattern row `i` by +0.125 keeps row `i` the unique nearest.
for i in 0..total {
let query: Vec<f32> = (0..dim).map(|j| ((i + j) as f32) + 0.125).collect();
let closest = chunk.find_closest(&query).unwrap();
assert_eq!(closest as usize, i);
test_batch(
&chunk,
query.as_slice(),
i,
0,
&lazy_format!("main iteration {}, {}", i, test_context),
);
// Miri is slow: one iteration of this loop is sufficient coverage.
if cfg!(miri) {
return;
}
}
// Queries past the last center must still resolve to some valid center.
for i in total..=(total + dim) {
let query: Vec<f32> = vec![i as f32; dim];
let closest = chunk.find_closest(&query).unwrap();
assert!((closest as usize) < chunk.num_centers());
test_batch(
&chunk,
query.as_slice(),
closest as usize,
0,
&lazy_format!("tail matching {}, {}", i, test_context),
);
}
// Duplicate the last row into row 0: ties must resolve to the lower index.
let last = data_aggregate.last().unwrap().clone();
data_aggregate[0].clone_from(&last);
let data = flatten(&data_aggregate);
let sliced = strided::StridedView::try_from(data.as_slice(), total, dim, dim).unwrap();
let chunk = Chunk::new(sliced).unwrap();
assert_eq!(chunk.num_centers(), total);
assert_eq!(chunk.dimension(), dim);
assert_eq!(
chunk.find_closest(&last).unwrap(),
0,
"ties must resolve to lower index, {}",
test_context
);
// After the duplication, center 1 holds the smallest ramp, so the all-zero
// inactive lanes land on index 1 — unless there are <= 2 centers, where
// index 0 wins (only option, or by tie-breaking).
let zero_matches = if chunk.num_centers() <= 2 { 0 } else { 1 };
test_batch(
&chunk,
&last,
0,
zero_matches,
&lazy_format!("ties resolve to first, {}", test_context),
);
}
// Sweep a grid of (dimension, total) shapes, reduced under Miri to keep
// interpreted execution time manageable.
#[test]
fn run_test_happy_path() {
let dims: Vec<usize> = if cfg!(miri) {
(7..=8).collect()
} else {
(1..=16).collect()
};
let totals: Vec<usize> = if cfg!(miri) {
vec![1, 2, 7, 8, 9, 15, 16, 17, 31, 32, 33]
} else {
[
(1..=17), (64..=103), (255..=257), ]
.into_iter()
.flatten()
.collect()
};
for &total in totals.iter() {
for &dim in dims.iter() {
println!("on {}, {}", dim, total);
test_chunk(dim, total);
}
}
}
// `Chunk::new` must reject degenerate shapes with descriptive error messages.
#[test]
fn test_chunk_construction_error() {
let chunk = Chunk::new(strided::StridedView::try_from(&[], 3, 0, 0).unwrap());
let err = chunk.unwrap_err();
assert!(
err.to_string()
.contains("cannot construct a Chunk from a source with zero dimensions")
);
let chunk = Chunk::new(strided::StridedView::try_from(&[], 0, 10, 10).unwrap());
let err = chunk.unwrap_err();
assert!(
err.to_string()
.contains("cannot construct a Chunk from a source with zero length")
);
}
// A query whose length differs from the chunk dimension must panic.
#[test]
#[should_panic(expected = "incorrect query dimension")]
fn test_find_closest_panics() {
let dim = 10;
let total = 13;
let data = flatten(&create_test_pattern(dim, total));
let sliced = strided::StridedView::try_from(data.as_slice(), total, dim, dim).unwrap();
let chunk = Chunk::new(sliced).unwrap();
let query: Vec<f32> = vec![0.0; total];
chunk.find_closest(query.as_slice());
}
// A batched query view with the wrong column count must panic.
#[test]
#[should_panic(expected = "incorrect query dimension")]
fn test_find_closest_batch_panics_on_dim_mismatch() {
let dim = 10;
let total = 13;
let data = flatten(&create_test_pattern(dim, total));
let sliced = strided::StridedView::try_from(data.as_slice(), total, dim, dim).unwrap();
let chunk = Chunk::new(sliced).unwrap();
let query: Vec<f32> = vec![0.0; 4 * total];
let query_view =
strided::StridedView::try_from(query.as_slice(), Chunk::batchsize(), total, total)
.unwrap();
chunk.find_closest_batch(query_view);
}
// A batched query view with the wrong row count must panic.
#[test]
#[should_panic(expected = "argument StridedView must have a length of")]
fn test_find_closest_batch_panics_on_non_batch_length() {
let dim = 10;
let total = 13;
let data = flatten(&create_test_pattern(dim, total));
let sliced = strided::StridedView::try_from(data.as_slice(), total, dim, dim).unwrap();
let chunk = Chunk::new(sliced).unwrap();
let query: Vec<f32> = vec![0.0; (Chunk::batchsize() + 1) * dim];
let query_view =
strided::StridedView::try_from(query.as_slice(), Chunk::batchsize() + 1, dim, dim)
.unwrap();
chunk.find_closest_batch(query_view);
}
// `get` must bounds-check its row argument.
#[test]
#[should_panic(expected = "row 5 must be less than 5")]
fn get_panics_on_row() {
let data = views::Matrix::new(0.0, 5, 10);
let chunk = Chunk::new(data.as_view().into()).unwrap();
chunk.get(5, 1);
}
// `get` must bounds-check its column argument.
#[test]
#[should_panic(expected = "col 5 must be less than 5")]
fn get_panics_on_col() {
let data = views::Matrix::new(0.0, 10, 5);
let chunk = Chunk::new(data.as_view().into()).unwrap();
chunk.get(1, 5);
}
// Miri interprets execution and is slow: run a single trial under Miri and
// ten trials in a normal test run.
cfg_if::cfg_if! {
if #[cfg(miri)] {
const PROCESS_INTO_TRIALS: usize = 1;
} else {
const PROCESS_INTO_TRIALS: usize = 10;
}
}
// Property check: `process_into` must agree exactly with the scalar reference
// distance implementations. Small integer-valued data keeps the f32
// arithmetic exact, which is what makes `assert_eq!` on floats sound here.
fn test_process_into_impl(dim: usize, total: usize, rng: &mut StdRng) {
let distribution = Uniform::<i32>::new(-10, 10).unwrap();
let base =
views::Matrix::<f32>::new(views::Init(|| distribution.sample(rng) as f32), total, dim);
let chunk = Chunk::new(base.as_view().into()).unwrap();
let mut input = vec![0.0; dim];
let mut output = vec![0.0; total];
for _ in 0..PROCESS_INTO_TRIALS {
// Fresh random query each trial; the output buffer is reused.
input
.iter_mut()
.for_each(|i| *i = distribution.sample(rng) as f32);
InnerProduct::process_into(&chunk, &input, &mut output);
std::iter::zip(base.row_iter(), output.iter()).for_each(|(row, got)| {
let expected: f32 = distance::InnerProduct::evaluate(row, input.as_slice());
assert_eq!(*got, expected);
});
SquaredL2::process_into(&chunk, &input, &mut output);
std::iter::zip(base.row_iter(), output.iter()).for_each(|(row, got)| {
let expected: f32 = distance::SquaredL2::evaluate(row, input.as_slice());
assert_eq!(*got, expected);
});
}
}
// Drive `test_process_into_impl` over a grid of shapes, trimmed under Miri.
#[test]
fn test_process_into() {
let mut rng = StdRng::seed_from_u64(0x21dfb5f35dfe5639);
let total_range = if cfg!(miri) { 1..48 } else { 1..64 };
let dim_range = if cfg!(miri) { 4..5 } else { 1..5 };
for total in total_range {
for dim in dim_range.clone() {
println!("on ({}, {})", total, dim);
test_process_into_impl(dim, total, &mut rng);
}
}
}
// `process_into` must panic when the query length mismatches the dimension.
#[test]
#[should_panic]
fn test_process_into_panics_on_from() {
let data = views::Matrix::<f32>::new(0.0, 5, 10);
let chunk = Chunk::new(data.as_view().into()).unwrap();
assert_eq!(chunk.dimension(), 10);
assert_eq!(chunk.num_centers(), 5);
let query: Vec<f32> = vec![0.0; chunk.dimension() + 1];
let mut dst = vec![0.0; chunk.num_centers()];
InnerProduct::process_into(&chunk, query.as_slice(), dst.as_mut_slice());
}
// `process_into` must panic when the output length mismatches the centers.
#[test]
#[should_panic]
fn test_process_into_panics_on_into() {
let data = views::Matrix::<f32>::new(0.0, 5, 10);
let chunk = Chunk::new(data.as_view().into()).unwrap();
assert_eq!(chunk.dimension(), 10);
assert_eq!(chunk.num_centers(), 5);
let query: Vec<f32> = vec![0.0; chunk.dimension()];
let mut dst = vec![0.0; chunk.num_centers() + 1];
InnerProduct::process_into(&chunk, query.as_slice(), dst.as_mut_slice());
}
}