use rand::Rng;
use rand_distr::{Distribution, Normal};
use crate::Half;
/// Scalar reference: sum of squared element-wise differences of two `Half` slices.
///
/// Every element is converted with `to_f32()` and widened to `f64`, so the
/// subtraction, the squaring and the accumulation are all carried out in `f64`.
///
/// The slices are expected to have equal length; this is only verified by a
/// `debug_assert_eq!`.
pub(crate) fn no_vector_compare_f16_as_f64(a: &[Half], b: &[Half]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    // Pair every element of `a` with the element of `b` at the same index: a
    // too-short `b` still panics here, surplus trailing elements are ignored.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0_f64, |acc, (x, y)| {
        acc + (f64::from(x.to_f32()) - f64::from(y.to_f32())).powi(2)
    })
}
/// Scalar reference: sum of squared element-wise differences of two `f32` slices.
///
/// Each element is widened to `f64` before subtracting, so the differences, the
/// squares and the running total are all computed in `f64`.
///
/// The slices are expected to have equal length; this is only verified by a
/// `debug_assert_eq!`.
pub(crate) fn no_vector_compare_f32_as_f64(a: &[f32], b: &[f32]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    // Same pairing as indexing `b` with every index of `a`: a too-short `b`
    // panics, surplus trailing elements of `b` are ignored.
    let b = &b[..a.len()];
    let mut total = 0.0_f64;
    for (&x, &y) in a.iter().zip(b) {
        let diff = f64::from(x) - f64::from(y);
        total += diff.powi(2);
    }
    total
}
/// Source of random argument vectors with element type `T`, used to feed
/// distance-function tests.
pub(crate) trait GenerateRandomArguments<T> {
/// Produces one vector for dimension `dim`, drawing any randomness from `rng`.
/// The implementations in this file return exactly `dim` elements.
fn generate<R: Rng>(&self, rng: &mut R, dim: usize) -> Vec<T>;
}
/// Generates `f32` vectors by calling `sample` on the normal distribution once
/// per element.
impl GenerateRandomArguments<f32> for Normal<f32> {
    fn generate<R: Rng>(&self, rng: &mut R, dim: usize) -> Vec<f32> {
        let mut samples: Vec<f32> = Vec::with_capacity(dim);
        for _ in 0..dim {
            samples.push(self.sample(rng));
        }
        samples
    }
}
/// Generates `Half` vectors: each element is sampled as `f32` from the normal
/// distribution and then converted with `diskann_wide::cast_f32_to_f16`.
impl GenerateRandomArguments<Half> for Normal<f32> {
    fn generate<R: Rng>(&self, rng: &mut R, dim: usize) -> Vec<Half> {
        let mut samples: Vec<Half> = Vec::with_capacity(dim);
        for _ in 0..dim {
            let value: f32 = self.sample(rng);
            samples.push(diskann_wide::cast_f32_to_f16(value));
        }
        samples
    }
}
/// Generates `i8` vectors by sampling `StandardUniform` once per element.
impl GenerateRandomArguments<i8> for rand::distr::StandardUniform {
    fn generate<R: Rng>(&self, rng: &mut R, dim: usize) -> Vec<i8> {
        // `take(dim)` limits the sampler to exactly `dim` calls.
        std::iter::repeat_with(|| self.sample(rng))
            .take(dim)
            .collect()
    }
}
/// Generates `u8` vectors by sampling `StandardUniform` once per element.
impl GenerateRandomArguments<u8> for rand::distr::StandardUniform {
    fn generate<R: Rng>(&self, rng: &mut R, dim: usize) -> Vec<u8> {
        // `take(dim)` limits the sampler to exactly `dim` calls.
        std::iter::repeat_with(|| self.sample(rng))
            .take(dim)
            .collect()
    }
}
/// In-place normalization of a sequence of values.
pub(crate) trait Normalize {
/// Rescales `self` in place. The slice implementations in this file divide
/// every element by the Euclidean norm and leave a zero-norm input unchanged.
fn normalize(&mut self);
}
/// Scales an `f32` slice to unit Euclidean length, in place.
///
/// The norm is the square root of the `f32` sum of squares. A slice whose norm
/// compares equal to zero (including the empty slice) is left untouched.
impl Normalize for [f32] {
    fn normalize(&mut self) {
        let squared_sum: f32 = self.iter().map(|&v| v * v).sum();
        let length = squared_sum.sqrt();
        // Dividing by a zero norm would turn every element into NaN/inf.
        if length != 0.0 {
            for v in self.iter_mut() {
                *v /= length;
            }
        }
    }
}
/// Normalizes a `Half` slice by round-tripping through `f32`.
///
/// The elements are converted to `f32`, normalized with the `[f32]`
/// implementation, and written back via `diskann_wide::cast_f32_to_f16`.
impl Normalize for [Half] {
    fn normalize(&mut self) {
        let mut widened: Vec<f32> = Vec::with_capacity(self.len());
        for &h in self.iter() {
            widened.push(h.into());
        }
        widened.normalize();
        for (dst, &src) in self.iter_mut().zip(&widened) {
            *dst = diskann_wide::cast_f32_to_f16(src);
        }
    }
}
/// Wrapper around a generator `T` that marks its output as "to be normalized".
///
/// The `GenerateRandomArguments` implementations for `Normalized<Normal<f32>>`
/// generate a vector from the wrapped distribution and then call `normalize()`
/// on it before returning.
#[derive(Debug, Clone)]
pub(crate) struct Normalized<T>(pub(crate) T);
impl GenerateRandomArguments<f32> for Normalized<Normal<f32>> {
fn generate<R: Rng>(&self, rng: &mut R, dim: usize) -> Vec<f32> {
let mut v = self.0.generate(rng, dim);
v.normalize();
v
}
}
impl GenerateRandomArguments<Half> for Normalized<Normal<f32>> {
fn generate<R: Rng>(&self, rng: &mut R, dim: usize) -> Vec<Half> {
let mut v = self.0.generate(rng, dim);
v.normalize();
v
}
}
/// Fixed element values that `test_distance_function` exercises, as constant
/// vectors, before it runs any random trials.
pub(crate) trait CornerCases: Sized + Copy {
/// Returns the corner-case values for this element type.
fn corner_cases() -> Vec<Self>;
}
/// Corner cases for `f32`: zero, a negative/positive pair, and a larger positive value.
impl CornerCases for f32 {
    fn corner_cases() -> Vec<Self> {
        Vec::from([0.0, -5.0, 5.0, 10.0])
    }
}
/// Corner cases for `Half`: the `f32` corner cases, each converted with
/// `diskann_wide::cast_f32_to_f16` (same order).
impl CornerCases for Half {
    fn corner_cases() -> Vec<Self> {
        let base = <f32 as CornerCases>::corner_cases();
        let mut cases = Vec::with_capacity(base.len());
        for value in base {
            cases.push(diskann_wide::cast_f32_to_f16(value));
        }
        cases
    }
}
/// Corner cases for `i8`: the smallest value, the largest value, and zero.
impl CornerCases for i8 {
    fn corner_cases() -> Vec<Self> {
        [i8::MIN, i8::MAX, 0].to_vec()
    }
}
/// Corner cases for `u8`: the smallest value, the largest value, and a
/// mid-range value.
///
/// `u8::MIN` is already `0`, so listing `0` a second time would only repeat the
/// same case. `128` (the first value with the top bit set) is used as the third
/// case instead, keeping the list at three distinct entries.
impl CornerCases for u8 {
    fn corner_cases() -> Vec<Self> {
        vec![u8::MIN, u8::MAX, 128]
    }
}
/// A check that can be run against one pair of left/right argument slices.
///
/// `test_distance_function` calls `check` once per generated pair. The method
/// returns nothing, so a failure has to be signalled by the implementation
/// itself (for example by panicking).
pub(crate) trait DistanceChecker<Left, Right> {
/// Runs the check on `left` and `right`. Takes `&mut self`, so an
/// implementation may keep state across calls.
fn check(&mut self, left: &[Left], right: &[Right]);
}
/// Boxed `FnMut` callback taking a left and a right slice and returning `To`;
/// the closure may borrow data that lives for `'a`.
type BoxedFn<'a, Left, Right, To> = Box<dyn FnMut(&[Left], &[Right]) -> To + 'a>;
/// Bundles a function under test, a reference function, and a comparison
/// callback.
///
/// Its `DistanceChecker` implementation passes the same slices to both
/// functions and hands the two results to `compare`, under-test result first.
/// The result type `To` defaults to `f32`.
pub(crate) struct Checker<'a, Left, Right, To = f32> {
// Implementation being checked.
under_test: BoxedFn<'a, Left, Right, To>,
// Reference implementation it is compared against.
reference: BoxedFn<'a, Left, Right, To>,
// Called with `(under_test result, reference result)`.
compare: Box<dyn FnMut(To, To) + 'a>,
}
impl<'a, Left, Right, To> Checker<'a, Left, Right, To> {
pub(crate) fn new<L, R, C>(under_test: L, reference: R, compare: C) -> Self
where
L: FnMut(&[Left], &[Right]) -> To + 'a,
R: FnMut(&[Left], &[Right]) -> To + 'a,
C: FnMut(To, To) + 'a,
{
Self {
under_test: Box::new(under_test),
reference: Box::new(reference),
compare: Box::new(compare),
}
}
}
/// Runs both stored functions on the same input and forwards the results to
/// the comparison callback.
impl<Left, Right, To> DistanceChecker<Left, Right> for Checker<'_, Left, Right, To> {
    fn check(&mut self, left: &[Left], right: &[Right]) {
        // The function under test runs first, then the reference.
        let actual = (self.under_test)(left, right);
        let expected = (self.reference)(left, right);
        (self.compare)(actual, expected);
    }
}
/// Checker that wraps a single boxed closure; its `DistanceChecker`
/// implementation simply calls that closure with the left and right slices.
pub(crate) struct AdHocChecker<'a, Left, Right>(BoxedFn<'a, Left, Right, ()>);
impl<'a, Left, Right> AdHocChecker<'a, Left, Right> {
pub(crate) fn new<C>(f: C) -> Self
where
C: FnMut(&[Left], &[Right]) + 'a,
{
Self(Box::new(f))
}
}
impl<Left, Right> DistanceChecker<Left, Right> for AdHocChecker<'_, Left, Right> {
fn check(&mut self, left: &[Left], right: &[Right]) {
(self.0)(left, right)
}
}
/// Drives `checker` over a fixed grid of corner cases followed by random trials.
///
/// 1. For every pair in `Left::corner_cases()` x `Right::corner_cases()`, two
///    constant vectors of length `dim` are built and passed to `checker.check`.
/// 2. Then, `trials` times, a left vector is generated from `left_dist` and a
///    right vector from `right_dist` (in that order, both with `dim` and the
///    same `rng`) and passed to `checker.check`.
///
/// Nothing is returned; a failure has to be signalled by the checker itself
/// (for example by panicking).
pub(crate) fn test_distance_function<Left, Right, Check, LeftDist, RightDist, R>(
    mut checker: Check,
    left_dist: LeftDist,
    right_dist: RightDist,
    dim: usize,
    trials: usize,
    rng: &mut R,
) where
    Check: DistanceChecker<Left, Right>,
    Left: CornerCases,
    Right: CornerCases,
    LeftDist: GenerateRandomArguments<Left>,
    RightDist: GenerateRandomArguments<Right>,
    R: Rng,
{
    // Phase 1: every combination of corner-case values, as constant vectors.
    for left_value in Left::corner_cases() {
        let left = vec![left_value; dim];
        for right_value in Right::corner_cases() {
            let right = vec![right_value; dim];
            checker.check(&left, &right);
        }
    }

    // Phase 2: randomly generated vectors; left is drawn before right.
    (0..trials).for_each(|_| {
        let left = left_dist.generate(rng, dim);
        let right = right_dist.generate(rng, dim);
        checker.check(&left, &right);
    });
}
#[cfg(test)]
mod test_test_utils {
use rand::{Rng, SeedableRng};
use super::*;
/// Generates one vector for every dimension in `0..=max_dim`, asserts that the
/// generated length equals the requested dimension, and passes each element to
/// `checker` by reference.
fn test_generation_and_check_results<T, Dist, R, Checker>(
    distribution: &Dist,
    rng: &mut R,
    max_dim: usize,
    mut checker: Checker,
) where
    R: Rng,
    Dist: GenerateRandomArguments<T>,
    Checker: FnMut(&T),
{
    for dim in 0..=max_dim {
        let generated = distribution.generate(rng, dim);
        assert_eq!(generated.len(), dim);
        for value in &generated {
            checker(value);
        }
    }
}
/// Sampling `i8` vectors for dimensions 0 through 256 with a fixed seed must
/// produce every one of the 256 possible `i8` values at least once.
#[test]
fn test_i8_generation() {
    use std::collections::HashSet;

    let mut observed = HashSet::<i8>::new();
    let mut rng = rand::rngs::StdRng::seed_from_u64(0x078912AF);
    let uniform = rand::distr::StandardUniform {};
    test_generation_and_check_results(&uniform, &mut rng, 256, |value: &i8| {
        observed.insert(*value);
    });
    assert_eq!(observed.len(), 256);
}
/// Sampling `u8` vectors for dimensions 0 through 256 with a fixed seed must
/// produce every one of the 256 possible `u8` values at least once.
#[test]
fn test_u8_generation() {
    use std::collections::HashSet;

    let mut observed = HashSet::<u8>::new();
    let mut rng = rand::rngs::StdRng::seed_from_u64(0xdef053c);
    let uniform = rand::distr::StandardUniform {};
    test_generation_and_check_results(&uniform, &mut rng, 256, |value: &u8| {
        observed.insert(*value);
    });
    assert_eq!(observed.len(), 256);
}
/// Statistical sanity check for floating-point generation from `Normal<f32>`.
///
/// Samples with mean 0 and standard deviation 2 (seeded by `seed`) for
/// dimensions 0 through 256, converts every element to `f32`, and asserts that
/// * at least 65% of the samples lie within one standard deviation of the mean,
/// * the largest sample is at least three standard deviations above the mean,
/// * the smallest sample is at least three standard deviations below it.
fn test_float_generation<T>(seed: u64)
where
    rand_distr::Normal<f32>: GenerateRandomArguments<T>,
    T: Copy + Into<f32>,
{
    let mean: f32 = 0.0;
    let std_dev: f32 = 2.0;
    let normal = rand_distr::Normal::new(mean, std_dev).unwrap();
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

    let mut smallest = f32::MAX;
    let mut largest = f32::MIN;
    let mut within_one_std: u64 = 0;
    let mut samples: u64 = 0;

    test_generation_and_check_results(&normal, &mut rng, 256, |raw: &T| {
        let value: f32 = (*raw).into();
        smallest = smallest.min(value);
        largest = largest.max(value);
        samples += 1;
        if (value - mean).abs() <= std_dev {
            within_one_std += 1;
        }
    });

    let fraction = (within_one_std as f64) / (samples as f64);
    assert!(fraction >= 0.65);
    assert!(largest >= mean + 3.0 * std_dev);
    assert!(smallest <= mean - 3.0 * std_dev);
}
/// Runs the float-generation sanity check for `f32` with a fixed seed.
#[test]
fn test_f32_generation() {
    let seed: u64 = 0x132435;
    test_float_generation::<f32>(seed);
}
/// Runs the float-generation sanity check for `Half` with a fixed seed.
#[test]
fn test_f16_generation() {
    let seed: u64 = 0x978675;
    test_float_generation::<Half>(seed);
}
/// Plain `f32` inner product of `x` and `y`. Elements are paired up in order;
/// if the lengths differ, the surplus tail of the longer slice is ignored.
fn simple_inner_product_f32(x: &[f32], y: &[f32]) -> f32 {
    x.iter().zip(y).map(|(&lhs, &rhs)| lhs * rhs).sum::<f32>()
}
/// `test_distance_function` must invoke the function under test, the reference
/// and the comparison exactly once per corner-case pair and once per random
/// trial, always with vectors of the requested dimension.
///
/// The "under test" closure deliberately returns the reference value plus one,
/// and the comparison asserts exactly that offset, which ties each result to
/// the closure that produced it.
#[test]
fn test_test_distance_function() {
    let dim = 10;
    let trials = 100;
    let mut calls_under_test = 0;
    let mut calls_reference = 0;
    let mut calls_compare = 0;

    let checker = Checker::<f32, f32, f32>::new(
        |left, right| {
            assert!(left.len() == dim);
            assert!(right.len() == dim);
            calls_under_test += 1;
            simple_inner_product_f32(left, right) + 1.0
        },
        |left, right| {
            calls_reference += 1;
            simple_inner_product_f32(left, right)
        },
        |a: f32, b: f32| {
            calls_compare += 1;
            assert_eq!(a, b + 1.0);
        },
    );

    let mut rng = rand::rngs::StdRng::seed_from_u64(5);
    let left_dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
    let right_dist = rand_distr::Normal::new(0.0, 1.0).unwrap();
    test_distance_function(checker, left_dist, right_dist, dim, trials, &mut rng);

    // One call per (left corner, right corner) pair plus one per random trial.
    let corner_pairs = f32::corner_cases().len() * f32::corner_cases().len();
    let expected_calls = corner_pairs + trials;
    assert_eq!(calls_under_test, expected_calls);
    assert_eq!(calls_reference, expected_calls);
    assert_eq!(calls_compare, expected_calls);
}
/// A panic raised inside the checker must escape `test_distance_function`.
///
/// `expected = "panic"` ties the test to the checker's own panic message, so
/// a panic with a different message raised elsewhere on the way does not make
/// the test pass by accident.
#[test]
#[should_panic(expected = "panic")]
fn test_error_propagation() {
    let checker = AdHocChecker::<u8, u8>::new(|_, _| panic!("panic"));
    let mut rng = rand::rngs::StdRng::seed_from_u64(64);
    test_distance_function(
        checker,
        rand::distr::StandardUniform {},
        rand::distr::StandardUniform {},
        5,
        10,
        &mut rng,
    );
}
}